diff --git a/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java b/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
index 005f6ad..15f6869 100644
--- a/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
+++ b/tests/cts/net/src/android/net/cts/NetworkStatsManagerTest.java
@@ -96,8 +96,10 @@
 import java.net.UnknownHostException;
 import java.text.MessageFormat;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.TimeUnit;
 
@@ -712,27 +714,56 @@
         }
     }
 
-    class QueryResult {
-        public final int tag;
-        public final int state;
-        public final long total;
+    private static class QueryResults {
+        private static class QueryKey {
+            private final int mTag;
+            private final int mState;
 
-        QueryResult(int tag, int state, NetworkStats stats) {
-            this.tag = tag;
-            this.state = state;
-            total = getTotalAndAssertNotEmpty(stats, tag, state);
+            QueryKey(int tag, int state) {
+                this.mTag = tag;
+                this.mState = state;
+            }
+
+            @Override
+            public boolean equals(Object o) {
+                if (this == o) return true;
+                if (!(o instanceof QueryKey)) return false;
+                final QueryKey other = (QueryKey) o;
+                return mTag == other.mTag && mState == other.mState;
+            }
+
+            @Override
+            public int hashCode() {
+                return Objects.hash(mTag, mState);
+            }
+
+            @Override
+            public String toString() {
+                return String.format("QueryKey(tag=%s, state=%s)", tagToString(mTag),
+                        stateToString(mState));
+            }
         }
 
-        public String toString() {
-            return String.format("QueryResult(tag=%s state=%s total=%d)",
-                    tagToString(tag), stateToString(state), total);
+        private final Map<QueryKey, Long> mSnapshot = new HashMap<>();  // Totals by (tag, state).
+
+        public long get(int tag, int state) {
+            final QueryKey key = new QueryKey(tag, state);
+            // All results are expected to be put() before they are read.
+            return Objects.requireNonNull(mSnapshot.get(key), "No result stored for " + key);
+        }
+
+        public void put(int tag, int state, long total) {
+            mSnapshot.put(new QueryKey(tag, state), total);
         }
     }
 
-    private NetworkStats getNetworkStatsForTagState(int i, int tag, int state) {
-        return mNsm.queryDetailsForUidTagState(
+    private long getTotalForTagState(int i, int tag, int state, boolean assertNotEmpty) {
+        try (NetworkStats stats = mNsm.queryDetailsForUidTagState(
                 mNetworkInterfacesToTest[i].getNetworkType(), getSubscriberId(i),
                 mStartTime, mEndTime, Process.myUid(), tag, state);
+        ) {  // Closes stats even if an assertion in getTotal() fails.
+            return getTotal(stats, tag, state, assertNotEmpty);
+        }
     }
 
     private void assertWithinPercentage(String msg, long expected, long actual, int percentage) {
@@ -743,21 +774,12 @@
         assertTrue(msg, upperBound >= actual);
     }
 
-    private void assertAlmostNoUnexpectedTraffic(NetworkStats result, int expectedTag,
+    private void assertAlmostNoUnexpectedTraffic(long total, int expectedTag,
             int expectedState, long maxUnexpected) {
-        long total = 0;
-        NetworkStats.Bucket bucket = new NetworkStats.Bucket();
-        while (result.hasNextBucket()) {
-            assertTrue(result.getNextBucket(bucket));
-            total += bucket.getRxBytes() + bucket.getTxBytes();
-        }
         if (total <= maxUnexpected) return;
 
-        fail(String.format("More than %d bytes of traffic when querying for "
-                + "tag %s state %s. Last bucket: uid=%d tag=%s state=%s bytes=%d/%d",
-                maxUnexpected, tagToString(expectedTag), stateToString(expectedState),
-                bucket.getUid(), tagToString(bucket.getTag()), stateToString(bucket.getState()),
-                bucket.getRxBytes(), bucket.getTxBytes()));
+        fail(String.format("%d bytes of traffic (more than %d) when querying for tag %s state %s.",
+                total, maxUnexpected, tagToString(expectedTag), stateToString(expectedState)));
     }
 
     @ConnectivityDiagnosticsCollector.CollectTcpdumpOnFailure
@@ -770,66 +792,59 @@
             // Relatively large tolerance to accommodate for history bucket size.
             requestNetworkAndGenerateTraffic(i, LONG_TOLERANCE);
             setAppOpsMode(AppOpsManager.OPSTR_GET_USAGE_STATS, "allow");
-            NetworkStats result = null;
-            try {
-                int currentState = isInForeground() ? STATE_FOREGROUND : STATE_DEFAULT;
-                int otherState = (currentState == STATE_DEFAULT) ? STATE_FOREGROUND : STATE_DEFAULT;
 
-                int[] tagsWithTraffic = {NETWORK_TAG, TAG_NONE};
-                int[] statesWithTraffic = {currentState, STATE_ALL};
-                ArrayList<QueryResult> resultsWithTraffic = new ArrayList<>();
+            int currentState = isInForeground() ? STATE_FOREGROUND : STATE_DEFAULT;
+            int otherState = (currentState == STATE_DEFAULT) ? STATE_FOREGROUND : STATE_DEFAULT;
 
-                int[] statesWithNoTraffic = {otherState};
-                int[] tagsWithNoTraffic = {NETWORK_TAG + 1};
-                ArrayList<QueryResult> resultsWithNoTraffic = new ArrayList<>();
+            final List<Integer> statesWithTraffic = List.of(currentState, STATE_ALL);
+            final List<Integer> statesWithNoTraffic = List.of(otherState);
+            final ArrayList<Integer> allStates = new ArrayList<>();
+            allStates.addAll(statesWithTraffic);
+            allStates.addAll(statesWithNoTraffic);
 
-                // Expect to see traffic when querying for any combination of a tag in
-                // tagsWithTraffic and a state in statesWithTraffic.
-                for (int tag : tagsWithTraffic) {
-                    for (int state : statesWithTraffic) {
-                        result = getNetworkStatsForTagState(i, tag, state);
-                        resultsWithTraffic.add(new QueryResult(tag, state, result));
-                        result.close();
-                        result = null;
-                    }
-                }
+            final List<Integer> tagsWithTraffic = List.of(NETWORK_TAG, TAG_NONE);
+            final List<Integer> tagsWithNoTraffic = List.of(NETWORK_TAG + 1);
+            final ArrayList<Integer> allTags = new ArrayList<>();
+            allTags.addAll(tagsWithTraffic);
+            allTags.addAll(tagsWithNoTraffic);
 
-                // Expect that the results are within a few percentage points of each other.
-                // This is ensures that FIN retransmits after the transfer is complete don't cause
-                // the test to be flaky. The test URL currently returns just over 100k so this
-                // should not be too noisy. It also ensures that the traffic sent by the test
-                // harness, which is untagged, won't cause a failure.
-                long firstTotal = resultsWithTraffic.get(0).total;
-                for (QueryResult queryResult : resultsWithTraffic) {
-                    assertWithinPercentage(queryResult + "", firstTotal, queryResult.total, 16);
-                }
+            QueryResults results = new QueryResults();
 
-                // Expect to see no traffic when querying for any tag in tagsWithNoTraffic or any
-                // state in statesWithNoTraffic.
-                for (int tag : tagsWithNoTraffic) {
-                    for (int state : statesWithTraffic) {
-                        result = getNetworkStatsForTagState(i, tag, state);
-                        assertAlmostNoUnexpectedTraffic(result, tag, state, firstTotal / 100);
-                        result.close();
-                        result = null;
-                    }
-                }
-                for (int tag : tagsWithTraffic) {
-                    for (int state : statesWithNoTraffic) {
-                        result = getNetworkStatsForTagState(i, tag, state);
-                        assertAlmostNoUnexpectedTraffic(result, tag, state, firstTotal / 100);
-                        result.close();
-                        result = null;
-                    }
-                }
-            } finally {
-                if (result != null) {
-                    result.close();
+            // Collect results for all combinations of tags and states.
+            for (int tag : allTags) {
+                for (int state : allStates) {
+                    final boolean assertNotEmpty = tagsWithTraffic.contains(tag)
+                            && statesWithTraffic.contains(state);
+                    final long total = getTotalForTagState(i, tag, state, assertNotEmpty);
+                    results.put(tag, state, total);
                 }
             }
+
+            // Expect that the results are within a few percentage points of each other.
+            // This is ensures that FIN retransmits after the transfer is complete don't cause
+            // the test to be flaky. The test URL currently returns just over 100k so this
+            // should not be too noisy. It also ensures that the traffic sent by the test
+            // harness, which is untagged, won't cause a failure.
+            long totalOfNetworkTagAndCurrentState = results.get(NETWORK_TAG, currentState);
+            for (int tag : allTags) {
+                for (int state : allStates) {
+                    final long result = results.get(tag, state);
+                    final String queryKeyStr = new QueryResults.QueryKey(tag, state).toString();
+                    if (tagsWithTraffic.contains(tag) && statesWithTraffic.contains(state)) {
+                        assertWithinPercentage(queryKeyStr,
+                                totalOfNetworkTagAndCurrentState, result, 16);
+                    } else {
+                        // Expect to see no traffic when querying for any combination with tag
+                        // in tagsWithNoTraffic or any state in statesWithNoTraffic.
+                        assertAlmostNoUnexpectedTraffic(result, tag, state,
+                                totalOfNetworkTagAndCurrentState / 100);
+                    }
+                }
+            }
+
             setAppOpsMode(AppOpsManager.OPSTR_GET_USAGE_STATS, "deny");
             try {
-                result = mNsm.queryDetailsForUidTag(
+                mNsm.queryDetailsForUidTag(
                         mNetworkInterfacesToTest[i].getNetworkType(), getSubscriberId(i),
                         mStartTime, mEndTime, Process.myUid(), NETWORK_TAG);
                 fail("negative testUidDetails fails: no exception thrown.");
@@ -902,7 +917,7 @@
         }
     }
 
-    private String tagToString(Integer tag) {
+    private static String tagToString(Integer tag) {
         if (tag == null) return "null";
         switch (tag) {
             case TAG_NONE:
@@ -912,7 +927,7 @@
         }
     }
 
-    private String stateToString(Integer state) {
+    private static String stateToString(Integer state) {
         if (state == null) return "null";
         switch (state) {
             case STATE_ALL:
@@ -925,8 +940,8 @@
         throw new IllegalArgumentException("Unknown state " + state);
     }
 
-    private long getTotalAndAssertNotEmpty(NetworkStats result, Integer expectedTag,
-            Integer expectedState) {
+    private long getTotal(NetworkStats result, Integer expectedTag,
+            Integer expectedState, boolean assertNotEmpty) {
         assertTrue(result != null);
         NetworkStats.Bucket bucket = new NetworkStats.Bucket();
         long totalTxPackets = 0;
@@ -951,16 +966,18 @@
         assertFalse(result.getNextBucket(bucket));
         String msg = String.format("uid %d tag %s state %s",
                 Process.myUid(), tagToString(expectedTag), stateToString(expectedState));
-        assertTrue("No Rx bytes usage for " + msg, totalRxBytes > 0);
-        assertTrue("No Rx packets usage for " + msg, totalRxPackets > 0);
-        assertTrue("No Tx bytes usage for " + msg, totalTxBytes > 0);
-        assertTrue("No Tx packets usage for " + msg, totalTxPackets > 0);
+        if (assertNotEmpty) {
+            assertTrue("No Rx bytes usage for " + msg, totalRxBytes > 0);
+            assertTrue("No Rx packets usage for " + msg, totalRxPackets > 0);
+            assertTrue("No Tx bytes usage for " + msg, totalTxBytes > 0);
+            assertTrue("No Tx packets usage for " + msg, totalTxPackets > 0);
+        }
 
         return totalRxBytes + totalTxBytes;
     }
 
     private long getTotalAndAssertNotEmpty(NetworkStats result) {
-        return getTotalAndAssertNotEmpty(result, null, STATE_ALL);
+        return getTotal(result, null, STATE_ALL, true /* assertNotEmpty */);
     }
 
     private void assertTimestamps(final NetworkStats.Bucket bucket) {
