Merge "Update TetheringTests for new connectivity shims"
diff --git a/Tethering/src/com/android/networkstack/tethering/OffloadController.java b/Tethering/src/com/android/networkstack/tethering/OffloadController.java
index 44e3916..beb1821 100644
--- a/Tethering/src/com/android/networkstack/tethering/OffloadController.java
+++ b/Tethering/src/com/android/networkstack/tethering/OffloadController.java
@@ -26,6 +26,8 @@
 import static android.net.netstats.provider.NetworkStatsProvider.QUOTA_UNLIMITED;
 import static android.provider.Settings.Global.TETHER_OFFLOAD_DISABLED;
 
+import static com.android.networkstack.tethering.OffloadHardwareInterface.OFFLOAD_HAL_VERSION_1_0;
+import static com.android.networkstack.tethering.OffloadHardwareInterface.OFFLOAD_HAL_VERSION_1_1;
 import static com.android.networkstack.tethering.OffloadHardwareInterface.OFFLOAD_HAL_VERSION_NONE;
 import static com.android.networkstack.tethering.TetheringConfiguration.DEFAULT_TETHER_OFFLOAD_POLL_INTERVAL_MS;
 
@@ -114,11 +116,42 @@
     private ConcurrentHashMap<String, ForwardedStats> mForwardedStats =
             new ConcurrentHashMap<>(16, 0.75F, 1);
 
+    /** Per-interface data warning and limit in bytes; Long.MAX_VALUE means unlimited. */
+    private static final class InterfaceQuota {
+        public static final InterfaceQuota MAX_VALUE =
+                new InterfaceQuota(Long.MAX_VALUE, Long.MAX_VALUE);
+        public final long warningBytes;
+        public final long limitBytes;
+
+        InterfaceQuota(long warningBytes, long limitBytes) {
+            this.warningBytes = warningBytes;
+            this.limitBytes = limitBytes;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) return true;
+            if (!(o instanceof InterfaceQuota)) return false;
+            final InterfaceQuota that = (InterfaceQuota) o;
+            return warningBytes == that.warningBytes && limitBytes == that.limitBytes;
+        }
+
+        @Override
+        public int hashCode() {
+            return 31 * Long.hashCode(warningBytes) + Long.hashCode(limitBytes);
+        }
+
+        @Override
+        public String toString() {
+            return "InterfaceQuota{" + "warning=" + warningBytes + ", limit=" + limitBytes + '}';
+        }
+    }
+
     // Maps upstream interface names to interface quotas.
     // Always contains the latest value received from the framework for each interface, regardless
     // of whether offload is currently running (or is even supported) on that interface. Only
     // includes upstream interfaces that have a quota set.
-    private HashMap<String, Long> mInterfaceQuotas = new HashMap<>();
+    private HashMap<String, InterfaceQuota> mInterfaceQuotas = new HashMap<>();
 
     // Tracking remaining alert quota. Unlike limit quota is subject to interface, the alert
     // quota is interface independent and global for tether offload. Note that this is only
@@ -250,6 +283,18 @@
                     }
 
                     @Override
+                    public void onWarningReached() {
+                        if (!started()) return;
+                        mLog.log("onWarningReached");
+
+                        updateStatsForCurrentUpstream();
+                        if (mStatsProvider != null) {
+                            mStatsProvider.pushTetherStats();
+                            mStatsProvider.notifyWarningReached();
+                        }
+                    }
+
+                    @Override
                     public void onNatTimeoutUpdate(int proto,
                                                    String srcAddr, int srcPort,
                                                    String dstAddr, int dstPort) {
@@ -263,7 +308,8 @@
             mLog.i("tethering offload control not supported");
             stop();
         } else {
-            mLog.log("tethering offload started");
+            mLog.log("tethering offload started, version: "
+                    + OffloadHardwareInterface.halVerToString(mControlHalVersion));
             mNatUpdateCallbacksReceived = 0;
             mNatUpdateNetlinkErrors = 0;
             maybeSchedulePollingStats();
@@ -322,24 +368,35 @@
 
         @Override
         public void onSetLimit(String iface, long quotaBytes) {
+            onSetWarningAndLimit(iface, QUOTA_UNLIMITED, quotaBytes);
+        }
+
+        @Override
+        public void onSetWarningAndLimit(@NonNull String iface,
+                long warningBytes, long limitBytes) {
             // Listen for all iface is necessary since upstream might be changed after limit
             // is set.
             mHandler.post(() -> {
-                final Long curIfaceQuota = mInterfaceQuotas.get(iface);
+                final InterfaceQuota curIfaceQuota = mInterfaceQuotas.get(iface);
+                final InterfaceQuota newIfaceQuota = new InterfaceQuota(
+                        warningBytes == QUOTA_UNLIMITED ? Long.MAX_VALUE : warningBytes,
+                        limitBytes == QUOTA_UNLIMITED ? Long.MAX_VALUE : limitBytes);
 
                 // If the quota is set to unlimited, the value set to HAL is Long.MAX_VALUE,
                 // which is ~8.4 x 10^6 TiB, no one can actually reach it. Thus, it is not
                 // useful to set it multiple times.
                 // Otherwise, the quota needs to be updated to tell HAL to re-count from now even
                 // if the quota is the same as the existing one.
-                if (null == curIfaceQuota && QUOTA_UNLIMITED == quotaBytes) return;
+                if (null == curIfaceQuota && InterfaceQuota.MAX_VALUE.equals(newIfaceQuota)) {
+                    return;
+                }
 
-                if (quotaBytes == QUOTA_UNLIMITED) {
+                if (InterfaceQuota.MAX_VALUE.equals(newIfaceQuota)) {
                     mInterfaceQuotas.remove(iface);
                 } else {
-                    mInterfaceQuotas.put(iface, quotaBytes);
+                    mInterfaceQuotas.put(iface, newIfaceQuota);
                 }
-                maybeUpdateDataLimit(iface);
+                maybeUpdateDataWarningAndLimit(iface);
             });
         }
 
@@ -374,7 +431,11 @@
 
         @Override
         public void onSetAlert(long quotaBytes) {
-            // TODO: Ask offload HAL to notify alert without stopping traffic.
+            // Only HAL V1.0 needs the software polling mechanism (see useStatsPolling()). Newer
+            // HALs report data warnings themselves, so set alert calls are ignored there.
+            if (!useStatsPolling()) {
+                return;
+            }
             // Post it to handler thread since it access remaining quota bytes.
             mHandler.post(() -> {
                 updateAlertQuota(quotaBytes);
@@ -459,24 +520,32 @@
 
     private boolean isPollingStatsNeeded() {
         return started() && mRemainingAlertQuota > 0
+                && useStatsPolling()
                 && !TextUtils.isEmpty(currentUpstreamInterface())
                 && mDeps.getTetherConfig() != null
                 && mDeps.getTetherConfig().getOffloadPollInterval()
                 >= DEFAULT_TETHER_OFFLOAD_POLL_INTERVAL_MS;
     }
 
-    private boolean maybeUpdateDataLimit(String iface) {
-        // setDataLimit may only be called while offload is occurring on this upstream.
+    private boolean useStatsPolling() {
+        return mControlHalVersion == OFFLOAD_HAL_VERSION_1_0;
+    }
+
+    private boolean maybeUpdateDataWarningAndLimit(String iface) {
+        // setDataLimit or setDataWarningAndLimit may only be called while offload is occurring
+        // on this upstream.
         if (!started() || !TextUtils.equals(iface, currentUpstreamInterface())) {
             return true;
         }
 
-        Long limit = mInterfaceQuotas.get(iface);
-        if (limit == null) {
-            limit = Long.MAX_VALUE;
+        final InterfaceQuota quota = mInterfaceQuotas.getOrDefault(iface, InterfaceQuota.MAX_VALUE);
+        final boolean ret;
+        if (mControlHalVersion >= OFFLOAD_HAL_VERSION_1_1) {
+            ret = mHwInterface.setDataWarningAndLimit(iface, quota.warningBytes, quota.limitBytes);
+        } else {
+            ret = mHwInterface.setDataLimit(iface, quota.limitBytes);
         }
-
-        return mHwInterface.setDataLimit(iface, limit);
+        return ret;
     }
 
     private void updateStatsForCurrentUpstream() {
@@ -630,7 +699,7 @@
         maybeUpdateStats(prevUpstream);
 
         // Data limits can only be set once offload is running on the upstream.
-        success = maybeUpdateDataLimit(iface);
+        success = maybeUpdateDataWarningAndLimit(iface);
         if (!success) {
             // If we failed to set a data limit, don't use this upstream, because we don't want to
             // blow through the data limit that we were told to apply.
diff --git a/Tethering/src/com/android/networkstack/tethering/OffloadHardwareInterface.java b/Tethering/src/com/android/networkstack/tethering/OffloadHardwareInterface.java
index 7685847..e3ac660 100644
--- a/Tethering/src/com/android/networkstack/tethering/OffloadHardwareInterface.java
+++ b/Tethering/src/com/android/networkstack/tethering/OffloadHardwareInterface.java
@@ -24,10 +24,10 @@
 import android.annotation.NonNull;
 import android.hardware.tetheroffload.config.V1_0.IOffloadConfig;
 import android.hardware.tetheroffload.control.V1_0.IOffloadControl;
-import android.hardware.tetheroffload.control.V1_0.ITetheringOffloadCallback;
 import android.hardware.tetheroffload.control.V1_0.NatTimeoutUpdate;
 import android.hardware.tetheroffload.control.V1_0.NetworkProtocol;
 import android.hardware.tetheroffload.control.V1_0.OffloadCallbackEvent;
+import android.hardware.tetheroffload.control.V1_1.ITetheringOffloadCallback;
 import android.net.netlink.NetlinkSocket;
 import android.net.netlink.StructNfGenMsg;
 import android.net.netlink.StructNlMsgHdr;
@@ -39,6 +39,7 @@
 import android.system.ErrnoException;
 import android.system.Os;
 import android.system.OsConstants;
+import android.util.Log;
 import android.util.Pair;
 
 import com.android.internal.annotations.VisibleForTesting;
@@ -140,6 +141,8 @@
         public void onSupportAvailable() {}
         /** Offload stopped because of usage limit reached. */
         public void onStoppedLimitReached() {}
+        /** Indicate that the data warning quota is reached. */
+        public void onWarningReached() {}
 
         /** Indicate to update NAT timeout. */
         public void onNatTimeoutUpdate(int proto,
@@ -381,7 +384,8 @@
                 (controlCb == null) ? "null"
                         : "0x" + Integer.toHexString(System.identityHashCode(controlCb)));
 
-        mTetheringOffloadCallback = new TetheringOffloadCallback(mHandler, mControlCallback, mLog);
+        mTetheringOffloadCallback = new TetheringOffloadCallback(
+                mHandler, mControlCallback, mLog, mOffloadControlVersion);
         final CbResults results = new CbResults();
         try {
             mOffloadControl.initOffload(
@@ -480,6 +484,33 @@
         return results.mSuccess;
     }
 
+    /** Set data warning and limit values to offload management process (HAL V1.1+ only). */
+    public boolean setDataWarningAndLimit(String iface, long warning, long limit) {
+        if (mOffloadControlVersion < OFFLOAD_HAL_VERSION_1_1) {
+            throw new IllegalArgumentException(
+                    "setDataWarningAndLimit is not supported below HAL V1.1");
+        }
+        final String logmsg =
+                String.format("setDataWarningAndLimit(%s, %d, %d)", iface, warning, limit);
+
+        final CbResults results = new CbResults();
+        try {
+            ((android.hardware.tetheroffload.control.V1_1.IOffloadControl) mOffloadControl)
+                    .setDataWarningAndLimit(
+                            iface, warning, limit,
+                            (boolean success, String errMsg) -> {
+                                results.mSuccess = success;
+                                results.mErrMsg = errMsg;
+                            });
+        } catch (RemoteException e) {
+            record(logmsg, e);
+            return false;
+        }
+
+        record(logmsg, results);
+        return results.mSuccess;
+    }
+
     /** Set upstream parameters to offload management process. */
     public boolean setUpstreamParameters(
             String iface, String v4addr, String v4gateway, ArrayList<String> v6gws) {
@@ -565,35 +596,64 @@
         public final Handler handler;
         public final ControlCallback controlCb;
         public final SharedLog log;
+        private final int mOffloadControlVersion;
 
-        TetheringOffloadCallback(Handler h, ControlCallback cb, SharedLog sharedLog) {
+        TetheringOffloadCallback(
+                Handler h, ControlCallback cb, SharedLog sharedLog, int offloadControlVersion) {
             handler = h;
             controlCb = cb;
             log = sharedLog;
+            this.mOffloadControlVersion = offloadControlVersion;
+        }
+
+        private void handleOnEvent(int event) {
+            switch (event) {
+                case OffloadCallbackEvent.OFFLOAD_STARTED:
+                    controlCb.onStarted();
+                    break;
+                case OffloadCallbackEvent.OFFLOAD_STOPPED_ERROR:
+                    controlCb.onStoppedError();
+                    break;
+                case OffloadCallbackEvent.OFFLOAD_STOPPED_UNSUPPORTED:
+                    controlCb.onStoppedUnsupported();
+                    break;
+                case OffloadCallbackEvent.OFFLOAD_SUPPORT_AVAILABLE:
+                    controlCb.onSupportAvailable();
+                    break;
+                case OffloadCallbackEvent.OFFLOAD_STOPPED_LIMIT_REACHED:
+                    controlCb.onStoppedLimitReached();
+                    break;
+                case android.hardware.tetheroffload.control
+                        .V1_1.OffloadCallbackEvent.OFFLOAD_WARNING_REACHED:
+                    controlCb.onWarningReached();
+                    break;
+                default:
+                    log.e("Unsupported OffloadCallbackEvent: " + event);
+            }
         }
 
         @Override
         public void onEvent(int event) {
+            // The implementation should never call onEvent() if the event is already reported
+            // through newer callback.
+            if (mOffloadControlVersion > OFFLOAD_HAL_VERSION_1_0) {
+                Log.wtf(TAG, "onEvent(" + event + ") fired on HAL "
+                        + halVerToString(mOffloadControlVersion));
+            }
             handler.post(() -> {
-                switch (event) {
-                    case OffloadCallbackEvent.OFFLOAD_STARTED:
-                        controlCb.onStarted();
-                        break;
-                    case OffloadCallbackEvent.OFFLOAD_STOPPED_ERROR:
-                        controlCb.onStoppedError();
-                        break;
-                    case OffloadCallbackEvent.OFFLOAD_STOPPED_UNSUPPORTED:
-                        controlCb.onStoppedUnsupported();
-                        break;
-                    case OffloadCallbackEvent.OFFLOAD_SUPPORT_AVAILABLE:
-                        controlCb.onSupportAvailable();
-                        break;
-                    case OffloadCallbackEvent.OFFLOAD_STOPPED_LIMIT_REACHED:
-                        controlCb.onStoppedLimitReached();
-                        break;
-                    default:
-                        log.e("Unsupported OffloadCallbackEvent: " + event);
-                }
+                handleOnEvent(event);
+            });
+        }
+
+        @Override
+        public void onEvent_1_1(int event) {
+            if (mOffloadControlVersion < OFFLOAD_HAL_VERSION_1_1) {
+                Log.wtf(TAG, "onEvent_1_1(" + event + ") fired on HAL "
+                        + halVerToString(mOffloadControlVersion));
+                return;
+            }
+            handler.post(() -> {
+                handleOnEvent(event);
             });
         }
 
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/OffloadControllerTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/OffloadControllerTest.java
index 88f2054..d800816 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/OffloadControllerTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/OffloadControllerTest.java
@@ -58,7 +58,6 @@
 import android.app.usage.NetworkStatsManager;
 import android.content.Context;
 import android.content.pm.ApplicationInfo;
-import android.net.ITetheringStatsProvider;
 import android.net.IpPrefix;
 import android.net.LinkAddress;
 import android.net.LinkProperties;
@@ -150,6 +149,7 @@
         when(mHardware.setUpstreamParameters(anyString(), any(), any(), any())).thenReturn(true);
         when(mHardware.getForwardedStats(any())).thenReturn(new ForwardedStats());
         when(mHardware.setDataLimit(anyString(), anyLong())).thenReturn(true);
+        when(mHardware.setDataWarningAndLimit(anyString(), anyLong(), anyLong())).thenReturn(true);
     }
 
     private void enableOffload() {
@@ -503,77 +503,167 @@
                 expectedUidStatsDiff);
     }
 
+    /**
+     * Test OffloadController with different combinations of HAL and framework versions can set
+     * data warning and/or limit correctly.
+     */
     @Test
-    public void testSetInterfaceQuota() throws Exception {
+    public void testSetDataWarningAndLimit() throws Exception {
+        // Verify the OffloadController is called by R framework, where the framework doesn't send
+        // warning.
+        checkSetDataWarningAndLimit(false, OFFLOAD_HAL_VERSION_1_0);
+        checkSetDataWarningAndLimit(false, OFFLOAD_HAL_VERSION_1_1);
+        // Verify the OffloadController is called by S+ framework, where the framework sends
+        // warning along with limit.
+        checkSetDataWarningAndLimit(true, OFFLOAD_HAL_VERSION_1_0);
+        checkSetDataWarningAndLimit(true, OFFLOAD_HAL_VERSION_1_1);
+    }
+
+    private void checkSetDataWarningAndLimit(boolean isProviderSetWarning, int controlVersion)
+            throws Exception {
         enableOffload();
         final OffloadController offload =
-                startOffloadController(OFFLOAD_HAL_VERSION_1_0, true /*expectStart*/);
+                startOffloadController(controlVersion, true /*expectStart*/);
 
         final String ethernetIface = "eth1";
         final String mobileIface = "rmnet_data0";
         final long ethernetLimit = 12345;
+        final long mobileWarning = 123456;
         final long mobileLimit = 12345678;
 
         final LinkProperties lp = new LinkProperties();
         lp.setInterfaceName(ethernetIface);
-        offload.setUpstreamLinkProperties(lp);
 
         final InOrder inOrder = inOrder(mHardware);
-        when(mHardware.setUpstreamParameters(any(), any(), any(), any())).thenReturn(true);
+        when(mHardware.setUpstreamParameters(
+                any(), any(), any(), any())).thenReturn(true);
         when(mHardware.setDataLimit(anyString(), anyLong())).thenReturn(true);
+        when(mHardware.setDataWarningAndLimit(anyString(), anyLong(), anyLong())).thenReturn(true);
+        offload.setUpstreamLinkProperties(lp);
+        // Applying an interface sends the initial quota to the hardware.
+        if (controlVersion >= OFFLOAD_HAL_VERSION_1_1) {
+            inOrder.verify(mHardware).setDataWarningAndLimit(ethernetIface, Long.MAX_VALUE,
+                    Long.MAX_VALUE);
+        } else {
+            inOrder.verify(mHardware).setDataLimit(ethernetIface, Long.MAX_VALUE);
+        }
+        inOrder.verifyNoMoreInteractions();
+
+        // Verify that setting unlimited again won't cause duplicate calls to the hardware.
+        if (isProviderSetWarning) {
+            mTetherStatsProvider.onSetWarningAndLimit(ethernetIface,
+                    NetworkStatsProvider.QUOTA_UNLIMITED, NetworkStatsProvider.QUOTA_UNLIMITED);
+        } else {
+            mTetherStatsProvider.onSetLimit(ethernetIface, NetworkStatsProvider.QUOTA_UNLIMITED);
+        }
+        waitForIdle();
+        inOrder.verifyNoMoreInteractions();
 
         // Applying an interface quota to the current upstream immediately sends it to the hardware.
-        mTetherStatsProvider.onSetLimit(ethernetIface, ethernetLimit);
+        if (isProviderSetWarning) {
+            mTetherStatsProvider.onSetWarningAndLimit(ethernetIface,
+                    NetworkStatsProvider.QUOTA_UNLIMITED, ethernetLimit);
+        } else {
+            mTetherStatsProvider.onSetLimit(ethernetIface, ethernetLimit);
+        }
         waitForIdle();
-        inOrder.verify(mHardware).setDataLimit(ethernetIface, ethernetLimit);
+        if (controlVersion >= OFFLOAD_HAL_VERSION_1_1) {
+            inOrder.verify(mHardware).setDataWarningAndLimit(ethernetIface, Long.MAX_VALUE,
+                    ethernetLimit);
+        } else {
+            inOrder.verify(mHardware).setDataLimit(ethernetIface, ethernetLimit);
+        }
         inOrder.verifyNoMoreInteractions();
 
         // Applying an interface quota to another upstream does not take any immediate action.
-        mTetherStatsProvider.onSetLimit(mobileIface, mobileLimit);
+        if (isProviderSetWarning) {
+            mTetherStatsProvider.onSetWarningAndLimit(mobileIface, mobileWarning, mobileLimit);
+        } else {
+            mTetherStatsProvider.onSetLimit(mobileIface, mobileLimit);
+        }
         waitForIdle();
-        inOrder.verify(mHardware, never()).setDataLimit(anyString(), anyLong());
+        if (controlVersion >= OFFLOAD_HAL_VERSION_1_1) {
+            inOrder.verify(mHardware, never()).setDataWarningAndLimit(anyString(), anyLong(),
+                    anyLong());
+        } else {
+            inOrder.verify(mHardware, never()).setDataLimit(anyString(), anyLong());
+        }
 
         // Switching to that upstream causes the quota to be applied if the parameters were applied
         // correctly.
         lp.setInterfaceName(mobileIface);
         offload.setUpstreamLinkProperties(lp);
         waitForIdle();
-        inOrder.verify(mHardware).setDataLimit(mobileIface, mobileLimit);
+        if (controlVersion >= OFFLOAD_HAL_VERSION_1_1) {
+            inOrder.verify(mHardware).setDataWarningAndLimit(mobileIface,
+                    isProviderSetWarning ? mobileWarning : Long.MAX_VALUE,
+                    mobileLimit);
+        } else {
+            inOrder.verify(mHardware).setDataLimit(mobileIface, mobileLimit);
+        }
 
-        // Setting a limit of ITetheringStatsProvider.QUOTA_UNLIMITED causes the limit to be set
+        // Setting a limit of NetworkStatsProvider.QUOTA_UNLIMITED causes the limit to be set
         // to Long.MAX_VALUE.
-        mTetherStatsProvider.onSetLimit(mobileIface, ITetheringStatsProvider.QUOTA_UNLIMITED);
+        if (isProviderSetWarning) {
+            mTetherStatsProvider.onSetWarningAndLimit(mobileIface,
+                    NetworkStatsProvider.QUOTA_UNLIMITED, NetworkStatsProvider.QUOTA_UNLIMITED);
+        } else {
+            mTetherStatsProvider.onSetLimit(mobileIface, NetworkStatsProvider.QUOTA_UNLIMITED);
+        }
         waitForIdle();
-        inOrder.verify(mHardware).setDataLimit(mobileIface, Long.MAX_VALUE);
+        if (controlVersion >= OFFLOAD_HAL_VERSION_1_1) {
+            inOrder.verify(mHardware).setDataWarningAndLimit(mobileIface, Long.MAX_VALUE,
+                    Long.MAX_VALUE);
+        } else {
+            inOrder.verify(mHardware).setDataLimit(mobileIface, Long.MAX_VALUE);
+        }
 
-        // If setting upstream parameters fails, then the data limit is not set.
+        // If setting upstream parameters fails, then the data warning and limit are not set.
         when(mHardware.setUpstreamParameters(any(), any(), any(), any())).thenReturn(false);
         lp.setInterfaceName(ethernetIface);
         offload.setUpstreamLinkProperties(lp);
-        mTetherStatsProvider.onSetLimit(mobileIface, mobileLimit);
+        if (isProviderSetWarning) {
+            mTetherStatsProvider.onSetWarningAndLimit(mobileIface, mobileWarning, mobileLimit);
+        } else {
+            mTetherStatsProvider.onSetLimit(mobileIface, mobileLimit);
+        }
         waitForIdle();
         inOrder.verify(mHardware, never()).setDataLimit(anyString(), anyLong());
+        inOrder.verify(mHardware, never()).setDataWarningAndLimit(anyString(), anyLong(),
+                anyLong());
 
-        // If setting the data limit fails while changing upstreams, offload is stopped.
+        // If setting the data warning and/or limit fails while changing upstreams, offload is
+        // stopped.
         when(mHardware.setUpstreamParameters(any(), any(), any(), any())).thenReturn(true);
         when(mHardware.setDataLimit(anyString(), anyLong())).thenReturn(false);
+        when(mHardware.setDataWarningAndLimit(anyString(), anyLong(), anyLong())).thenReturn(false);
         lp.setInterfaceName(mobileIface);
         offload.setUpstreamLinkProperties(lp);
-        mTetherStatsProvider.onSetLimit(mobileIface, mobileLimit);
+        if (isProviderSetWarning) {
+            mTetherStatsProvider.onSetWarningAndLimit(mobileIface, mobileWarning, mobileLimit);
+        } else {
+            mTetherStatsProvider.onSetLimit(mobileIface, mobileLimit);
+        }
         waitForIdle();
         inOrder.verify(mHardware).getForwardedStats(ethernetIface);
         inOrder.verify(mHardware).stopOffloadControl();
     }
 
     @Test
-    public void testDataLimitCallback() throws Exception {
+    public void testDataWarningAndLimitCallback() throws Exception {
         enableOffload();
-        final OffloadController offload =
-                startOffloadController(OFFLOAD_HAL_VERSION_1_0, true /*expectStart*/);
+        startOffloadController(OFFLOAD_HAL_VERSION_1_0, true /*expectStart*/);
 
         OffloadHardwareInterface.ControlCallback callback = mControlCallbackCaptor.getValue();
         callback.onStoppedLimitReached();
         mTetherStatsProviderCb.expectNotifyStatsUpdated();
+        mTetherStatsProviderCb.expectNotifyWarningOrLimitReached();
+
+        startOffloadController(OFFLOAD_HAL_VERSION_1_1, true /*expectStart*/);
+        callback = mControlCallbackCaptor.getValue();
+        callback.onWarningReached();
+        mTetherStatsProviderCb.expectNotifyStatsUpdated();
+        mTetherStatsProviderCb.expectNotifyWarningOrLimitReached();
     }
 
     @Test
@@ -761,9 +851,7 @@
         // Initialize with fake eth upstream.
         final String ethernetIface = "eth1";
         InOrder inOrder = inOrder(mHardware);
-        final LinkProperties lp = new LinkProperties();
-        lp.setInterfaceName(ethernetIface);
-        offload.setUpstreamLinkProperties(lp);
+        offload.setUpstreamLinkProperties(makeEthernetLinkProperties());
         // Previous upstream was null, so no stats are fetched.
         inOrder.verify(mHardware, never()).getForwardedStats(any());
 
@@ -796,4 +884,33 @@
         mTetherStatsProviderCb.assertNoCallback();
         verify(mHardware, never()).getForwardedStats(any());
     }
+
+    private static LinkProperties makeEthernetLinkProperties() {
+        final String ethernetIface = "eth1";
+        final LinkProperties lp = new LinkProperties();
+        lp.setInterfaceName(ethernetIface);
+        return lp;
+    }
+
+    private void checkSoftwarePollingUsed(int controlVersion) throws Exception {
+        enableOffload();
+        setOffloadPollInterval(DEFAULT_TETHER_OFFLOAD_POLL_INTERVAL_MS);
+        OffloadController offload =
+                startOffloadController(controlVersion, true /*expectStart*/);
+        offload.setUpstreamLinkProperties(makeEthernetLinkProperties());
+        mTetherStatsProvider.onSetAlert(0);
+        waitForIdle();
+        if (controlVersion >= OFFLOAD_HAL_VERSION_1_1) {
+            mTetherStatsProviderCb.assertNoCallback();
+        } else {
+            mTetherStatsProviderCb.expectNotifyAlertReached();
+        }
+        verify(mHardware, never()).getForwardedStats(any());
+    }
+
+    @Test
+    public void testSoftwarePollingUsed() throws Exception {
+        checkSoftwarePollingUsed(OFFLOAD_HAL_VERSION_1_0);
+        checkSoftwarePollingUsed(OFFLOAD_HAL_VERSION_1_1);
+    }
 }
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/OffloadHardwareInterfaceTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/OffloadHardwareInterfaceTest.java
index f4194e5..a8b3b92 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/OffloadHardwareInterfaceTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/OffloadHardwareInterfaceTest.java
@@ -22,22 +22,27 @@
 import static android.system.OsConstants.SOCK_STREAM;
 
 import static com.android.networkstack.tethering.OffloadHardwareInterface.OFFLOAD_HAL_VERSION_1_0;
+import static com.android.networkstack.tethering.OffloadHardwareInterface.OFFLOAD_HAL_VERSION_1_1;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThrows;
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.inOrder;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 import android.hardware.tetheroffload.config.V1_0.IOffloadConfig;
 import android.hardware.tetheroffload.control.V1_0.IOffloadControl;
-import android.hardware.tetheroffload.control.V1_0.ITetheringOffloadCallback;
 import android.hardware.tetheroffload.control.V1_0.NatTimeoutUpdate;
 import android.hardware.tetheroffload.control.V1_0.NetworkProtocol;
-import android.hardware.tetheroffload.control.V1_0.OffloadCallbackEvent;
+import android.hardware.tetheroffload.control.V1_1.ITetheringOffloadCallback;
+import android.hardware.tetheroffload.control.V1_1.OffloadCallbackEvent;
 import android.net.netlink.StructNfGenMsg;
 import android.net.netlink.StructNlMsgHdr;
 import android.net.util.SharedLog;
@@ -56,6 +61,7 @@
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
+import org.mockito.InOrder;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
@@ -76,7 +82,7 @@
     private OffloadHardwareInterface.ControlCallback mControlCallback;
 
     @Mock private IOffloadConfig mIOffloadConfig;
-    @Mock private IOffloadControl mIOffloadControl;
+    private IOffloadControl mIOffloadControl;
     @Mock private NativeHandle mNativeHandle;
 
     // Random values to test Netlink message.
@@ -84,8 +90,10 @@
     private static final short TEST_FLAGS = 263;
 
     class MyDependencies extends OffloadHardwareInterface.Dependencies {
-        MyDependencies(SharedLog log) {
+        private final int mMockControlVersion;
+        MyDependencies(SharedLog log, final int mockControlVersion) {
             super(log);
+            mMockControlVersion = mockControlVersion;
         }
 
         @Override
@@ -95,7 +103,19 @@
 
         @Override
         public Pair<IOffloadControl, Integer> getOffloadControl() {
-            return new Pair<IOffloadControl, Integer>(mIOffloadControl, OFFLOAD_HAL_VERSION_1_0);
+            switch (mMockControlVersion) {
+                case OFFLOAD_HAL_VERSION_1_0:
+                    mIOffloadControl = mock(IOffloadControl.class);
+                    break;
+                case OFFLOAD_HAL_VERSION_1_1:
+                    mIOffloadControl =
+                            mock(android.hardware.tetheroffload.control.V1_1.IOffloadControl.class);
+                    break;
+                default:
+                    throw new IllegalArgumentException("Invalid offload control version "
+                            + mMockControlVersion);
+            }
+            return new Pair<IOffloadControl, Integer>(mIOffloadControl, mMockControlVersion);
         }
 
         @Override
@@ -107,14 +127,13 @@
     @Before
     public void setUp() {
         MockitoAnnotations.initMocks(this);
-        final SharedLog log = new SharedLog("test");
-        mOffloadHw = new OffloadHardwareInterface(new Handler(mTestLooper.getLooper()), log,
-                new MyDependencies(log));
         mControlCallback = spy(new OffloadHardwareInterface.ControlCallback());
     }
 
-    // TODO: Pass version to test version specific operations.
-    private void startOffloadHardwareInterface() throws Exception {
+    private void startOffloadHardwareInterface(int controlVersion) throws Exception {
+        final SharedLog log = new SharedLog("test");
+        mOffloadHw = new OffloadHardwareInterface(new Handler(mTestLooper.getLooper()), log,
+                new MyDependencies(log, controlVersion));
         mOffloadHw.initOffloadConfig();
         mOffloadHw.initOffloadControl(mControlCallback);
         final ArgumentCaptor<ITetheringOffloadCallback> mOffloadCallbackCaptor =
@@ -125,7 +144,7 @@
 
     @Test
     public void testGetForwardedStats() throws Exception {
-        startOffloadHardwareInterface();
+        startOffloadHardwareInterface(OFFLOAD_HAL_VERSION_1_0);
         final OffloadHardwareInterface.ForwardedStats stats = mOffloadHw.getForwardedStats(RMNET0);
         verify(mIOffloadControl).getForwardedStats(eq(RMNET0), any());
         assertNotNull(stats);
@@ -133,7 +152,7 @@
 
     @Test
     public void testSetLocalPrefixes() throws Exception {
-        startOffloadHardwareInterface();
+        startOffloadHardwareInterface(OFFLOAD_HAL_VERSION_1_0);
         final ArrayList<String> localPrefixes = new ArrayList<>();
         localPrefixes.add("127.0.0.0/8");
         localPrefixes.add("fe80::/64");
 
     @Test
     public void testSetDataLimit() throws Exception {
-        startOffloadHardwareInterface();
+        startOffloadHardwareInterface(OFFLOAD_HAL_VERSION_1_0);
         final long limit = 12345;
         mOffloadHw.setDataLimit(RMNET0, limit);
         verify(mIOffloadControl).setDataLimit(eq(RMNET0), eq(limit), any());
     }
 
     @Test
+    public void testSetDataWarningAndLimit() throws Exception {
+        // Verify the V1.0 control HAL rejects the call with an IllegalArgumentException.
+        startOffloadHardwareInterface(OFFLOAD_HAL_VERSION_1_0);
+        final long warning = 12345;
+        final long limit = 67890;
+        assertThrows(IllegalArgumentException.class,
+                () -> mOffloadHw.setDataWarningAndLimit(RMNET0, warning, limit));
+        reset(mIOffloadControl); // Likely redundant: getOffloadControl() assigns a fresh mock.
+
+        // Verify the V1.1 control HAL receives the call with the same arguments.
+        startOffloadHardwareInterface(OFFLOAD_HAL_VERSION_1_1);
+        mOffloadHw.setDataWarningAndLimit(RMNET0, warning, limit);
+        verify((android.hardware.tetheroffload.control.V1_1.IOffloadControl) mIOffloadControl)
+                .setDataWarningAndLimit(eq(RMNET0), eq(warning), eq(limit), any());
+    }
+
+    @Test
     public void testSetUpstreamParameters() throws Exception {
-        startOffloadHardwareInterface();
+        startOffloadHardwareInterface(OFFLOAD_HAL_VERSION_1_0);
         final String v4addr = "192.168.10.1";
         final String v4gateway = "192.168.10.255";
         final ArrayList<String> v6gws = new ArrayList<>(0);
@@ -170,7 +206,7 @@
 
     @Test
     public void testUpdateDownstreamPrefix() throws Exception {
-        startOffloadHardwareInterface();
+        startOffloadHardwareInterface(OFFLOAD_HAL_VERSION_1_0);
         final String ifName = "wlan1";
         final String prefix = "192.168.43.0/24";
         mOffloadHw.addDownstreamPrefix(ifName, prefix);
@@ -182,7 +218,7 @@
 
     @Test
     public void testTetheringOffloadCallback() throws Exception {
-        startOffloadHardwareInterface();
+        startOffloadHardwareInterface(OFFLOAD_HAL_VERSION_1_0);
 
         mTetheringOffloadCallback.onEvent(OffloadCallbackEvent.OFFLOAD_STARTED);
         mTestLooper.dispatchAll();
@@ -221,10 +257,26 @@
                 eq(uint16(udpParams.src.port)),
                 eq(udpParams.dst.addr),
                 eq(uint16(udpParams.dst.port)));
+        reset(mControlCallback);
+
+        startOffloadHardwareInterface(OFFLOAD_HAL_VERSION_1_1);
+
+        // Verify the interface processes events that come from the V1.1 HAL.
+        mTetheringOffloadCallback.onEvent_1_1(OffloadCallbackEvent.OFFLOAD_STARTED);
+        mTestLooper.dispatchAll();
+        final InOrder inOrder = inOrder(mControlCallback);
+        inOrder.verify(mControlCallback).onStarted();
+        inOrder.verifyNoMoreInteractions();
+
+        mTetheringOffloadCallback.onEvent_1_1(OffloadCallbackEvent.OFFLOAD_WARNING_REACHED);
+        mTestLooper.dispatchAll();
+        inOrder.verify(mControlCallback).onWarningReached();
+        inOrder.verifyNoMoreInteractions();
     }
 
     @Test
     public void testSendIpv4NfGenMsg() throws Exception {
+        startOffloadHardwareInterface(OFFLOAD_HAL_VERSION_1_0);
         FileDescriptor writeSocket = new FileDescriptor();
         FileDescriptor readSocket = new FileDescriptor();
         try {
diff --git a/service/src/com/android/server/connectivity/NetworkRanker.java b/service/src/com/android/server/connectivity/NetworkRanker.java
index e839837..d7eb9c8 100644
--- a/service/src/com/android/server/connectivity/NetworkRanker.java
+++ b/service/src/com/android/server/connectivity/NetworkRanker.java
@@ -108,7 +108,58 @@
         }
     }
 
-    @Nullable private <T extends Scoreable> T getBestNetworkByPolicy(
+    private <T extends Scoreable> boolean isBadWiFi(@NonNull final T candidate) {
+        return candidate.getScore().hasPolicy(POLICY_EVER_VALIDATED_NOT_AVOIDED_WHEN_BAD)
+                && candidate.getCapsNoCopy().hasTransport(TRANSPORT_WIFI);
+    }
+
+    /**
+     * Apply the "yield to bad WiFi" policy.
+     *
+     * This function must run immediately after the validation policy, because a "bad WiFi"
+     * is identified as an ever-validated, not-avoided WiFi found in the rejected list.
+     * If any of the accepted networks has the "yield to bad WiFi" policy AND there are some
+     * bad WiFis in the rejected list: if some accepted networks don't yield, keep only those
+     * and reject the yielders; if all of them yield, accept the bad WiFis alongside them.
+     *
+     * This function returns nothing, but updates accepted and rejected in place.
+     *
+     * @param accepted networks accepted by the validation policy
+     * @param rejected networks rejected by the validation policy
+     */
+    private <T extends Scoreable> void applyYieldToBadWifiPolicy(@NonNull ArrayList<T> accepted,
+            @NonNull ArrayList<T> rejected) {
+        if (!CollectionUtils.any(accepted, n -> n.getScore().hasPolicy(POLICY_YIELD_TO_BAD_WIFI))) {
+            // No network with the policy : do nothing.
+            return;
+        }
+        if (!CollectionUtils.any(rejected, n -> isBadWiFi(n))) {
+            // No bad WiFi : do nothing.
+            return;
+        }
+        if (CollectionUtils.all(accepted, n -> n.getScore().hasPolicy(POLICY_YIELD_TO_BAD_WIFI))) {
+            // All validated networks yield : keep the bad WiFis alongside the yielders, since
+            // later policies (e.g. exiting) must compare them. Copy both lists first because
+            // partitionInto() writes its results into |accepted| and |rejected|.
+            final ArrayList<T> acceptedYielders = new ArrayList<>(accepted);
+            final ArrayList<T> rejectedWithBadWiFis = new ArrayList<>(rejected);
+            partitionInto(rejectedWithBadWiFis, n -> isBadWiFi(n), accepted, rejected);
+            accepted.addAll(acceptedYielders);
+            return;
+        }
+        // Only some of the validated networks yield to bad WiFi : keep only the ones that don't.
+        final ArrayList<T> acceptedWithYielders = new ArrayList<>(accepted);
+        partitionInto(acceptedWithYielders, n -> !n.getScore().hasPolicy(POLICY_YIELD_TO_BAD_WIFI),
+                accepted, rejected);
+    }
+
+    /**
+     * Get the best network among a list of candidates according to policy.
+     * @param candidates the candidates
+     * @param currentSatisfier the current satisfier, or null if none
+     * @return the best network
+     */
+    @Nullable public <T extends Scoreable> T getBestNetworkByPolicy(
             @NonNull List<T> candidates,
             @Nullable final T currentSatisfier) {
         // Used as working areas.
@@ -148,24 +199,15 @@
         if (accepted.size() == 1) return accepted.get(0);
         if (accepted.size() > 0 && rejected.size() > 0) candidates = new ArrayList<>(accepted);
 
-        // Yield to bad wifi policy : if any wifi has ever been validated (even if it's now
-        // unvalidated), and unless it's been explicitly avoided when bad in UI, then keep only
-        // networks that don't yield to such a wifi network.
-        final boolean anyWiFiEverValidated = CollectionUtils.any(candidates,
-                nai -> nai.getScore().hasPolicy(POLICY_EVER_VALIDATED_NOT_AVOIDED_WHEN_BAD)
-                        && nai.getCapsNoCopy().hasTransport(TRANSPORT_WIFI));
-        if (anyWiFiEverValidated) {
-            partitionInto(candidates, nai -> !nai.getScore().hasPolicy(POLICY_YIELD_TO_BAD_WIFI),
-                    accepted, rejected);
-            if (accepted.size() == 1) return accepted.get(0);
-            if (accepted.size() > 0 && rejected.size() > 0) candidates = new ArrayList<>(accepted);
-        }
-
         // If any network is validated (or should be accepted even if it's not validated), then
         // don't choose one that isn't.
         partitionInto(candidates, nai -> nai.getScore().hasPolicy(POLICY_IS_VALIDATED)
                         || nai.getScore().hasPolicy(POLICY_ACCEPT_UNVALIDATED),
                 accepted, rejected);
+        // Yield to bad wifi policy : if any accepted network has the "yield to bad WiFi" policy
+        // and there are bad WiFis connected, reject the yielders; if every accepted network
+        // yields, accept the bad WiFis alongside them instead.
+        applyYieldToBadWifiPolicy(accepted, rejected);
         if (accepted.size() == 1) return accepted.get(0);
         if (accepted.size() > 0 && rejected.size() > 0) candidates = new ArrayList<>(accepted);
 
@@ -194,16 +236,26 @@
         // subscription with the same transport.
         partitionInto(candidates, nai -> nai.getScore().hasPolicy(POLICY_TRANSPORT_PRIMARY),
                 accepted, rejected);
-        for (final Scoreable defaultSubNai : accepted) {
-            // Remove all networks without the DEFAULT_SUBSCRIPTION policy and the same transports
-            // as a network that has it.
-            final int[] transports = defaultSubNai.getCapsNoCopy().getTransportTypes();
-            candidates.removeIf(nai -> !nai.getScore().hasPolicy(POLICY_TRANSPORT_PRIMARY)
-                    && Arrays.equals(transports, nai.getCapsNoCopy().getTransportTypes()));
+        if (accepted.size() > 0) {
+            // Some networks are primary for their transport. For each transport, keep only the
+            // primary, but also keep all networks for which there isn't a primary (which are now
+            // in the |rejected| array).
+            // So for each primary network, remove from |rejected| all networks with the same
+            // transports as one of the primary networks. The remaining networks should be accepted.
+            for (final T defaultSubNai : accepted) {
+                final int[] transports = defaultSubNai.getCapsNoCopy().getTransportTypes();
+                rejected.removeIf(
+                        nai -> Arrays.equals(transports, nai.getCapsNoCopy().getTransportTypes()));
+            }
+            // Now the |rejected| list contains networks with transports for which there isn't
+            // a primary network. Add them back to the candidates.
+            accepted.addAll(rejected);
+            candidates = new ArrayList<>(accepted);
         }
         if (1 == candidates.size()) return candidates.get(0);
-        // It's guaranteed candidates.size() > 0 because there is at least one with the
-        // TRANSPORT_PRIMARY policy and only those without it were removed.
+        // If there was no primary network, then candidates.size() > 0 because it didn't
+        // change from the previous result. If there was one, candidates.size() > 0 is
+        // guaranteed because accepted.size() > 0 above.
 
         // If some of the networks have a better transport than others, keep only the ones with
         // the best transports.
diff --git a/tests/unit/java/android/net/ConnectivityManagerTest.java b/tests/unit/java/android/net/ConnectivityManagerTest.java
index 07f22a2..b8cd3f6 100644
--- a/tests/unit/java/android/net/ConnectivityManagerTest.java
+++ b/tests/unit/java/android/net/ConnectivityManagerTest.java
@@ -320,26 +320,34 @@
         NetworkCallback nullCallback = null;
         PendingIntent nullIntent = null;
 
-        mustFail(() -> { manager.requestNetwork(null, callback); });
-        mustFail(() -> { manager.requestNetwork(request, nullCallback); });
-        mustFail(() -> { manager.requestNetwork(request, callback, null); });
-        mustFail(() -> { manager.requestNetwork(request, callback, -1); });
-        mustFail(() -> { manager.requestNetwork(request, nullIntent); });
+        mustFail(() -> manager.requestNetwork(null, callback));
+        mustFail(() -> manager.requestNetwork(request, nullCallback));
+        mustFail(() -> manager.requestNetwork(request, callback, null));
+        mustFail(() -> manager.requestNetwork(request, callback, -1));
+        mustFail(() -> manager.requestNetwork(request, nullIntent));
 
-        mustFail(() -> { manager.registerNetworkCallback(null, callback, handler); });
-        mustFail(() -> { manager.registerNetworkCallback(request, null, handler); });
-        mustFail(() -> { manager.registerNetworkCallback(request, callback, null); });
-        mustFail(() -> { manager.registerNetworkCallback(request, nullIntent); });
+        mustFail(() -> manager.requestBackgroundNetwork(null, callback, handler));
+        mustFail(() -> manager.requestBackgroundNetwork(request, null, handler));
+        mustFail(() -> manager.requestBackgroundNetwork(request, callback, null));
 
-        mustFail(() -> { manager.registerDefaultNetworkCallback(null, handler); });
-        mustFail(() -> { manager.registerDefaultNetworkCallback(callback, null); });
+        mustFail(() -> manager.registerNetworkCallback(null, callback, handler));
+        mustFail(() -> manager.registerNetworkCallback(request, null, handler));
+        mustFail(() -> manager.registerNetworkCallback(request, callback, null));
+        mustFail(() -> manager.registerNetworkCallback(request, nullIntent));
 
-        mustFail(() -> { manager.registerSystemDefaultNetworkCallback(null, handler); });
-        mustFail(() -> { manager.registerSystemDefaultNetworkCallback(callback, null); });
+        mustFail(() -> manager.registerDefaultNetworkCallback(null, handler));
+        mustFail(() -> manager.registerDefaultNetworkCallback(callback, null));
 
-        mustFail(() -> { manager.unregisterNetworkCallback(nullCallback); });
-        mustFail(() -> { manager.unregisterNetworkCallback(nullIntent); });
-        mustFail(() -> { manager.releaseNetworkRequest(nullIntent); });
+        mustFail(() -> manager.registerSystemDefaultNetworkCallback(null, handler));
+        mustFail(() -> manager.registerSystemDefaultNetworkCallback(callback, null));
+
+        mustFail(() -> manager.registerBestMatchingNetworkCallback(null, callback, handler));
+        mustFail(() -> manager.registerBestMatchingNetworkCallback(request, null, handler));
+        mustFail(() -> manager.registerBestMatchingNetworkCallback(request, callback, null));
+
+        mustFail(() -> manager.unregisterNetworkCallback(nullCallback));
+        mustFail(() -> manager.unregisterNetworkCallback(nullIntent));
+        mustFail(() -> manager.releaseNetworkRequest(nullIntent));
     }
 
     static void mustFail(Runnable fn) {
diff --git a/tests/unit/java/com/android/server/connectivity/NetworkRankerTest.kt b/tests/unit/java/com/android/server/connectivity/NetworkRankerTest.kt
index 551b94c..4408958 100644
--- a/tests/unit/java/com/android/server/connectivity/NetworkRankerTest.kt
+++ b/tests/unit/java/com/android/server/connectivity/NetworkRankerTest.kt
@@ -1,5 +1,5 @@
 /*
- * Copyright (C) 2020 The Android Open Source Project
+ * Copyright (C) 2021 The Android Open Source Project
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,74 +17,156 @@
 package com.android.server.connectivity
 
 import android.net.NetworkCapabilities
-import android.net.NetworkRequest
+import android.net.NetworkCapabilities.TRANSPORT_CELLULAR
+import android.net.NetworkCapabilities.TRANSPORT_WIFI
 import android.net.NetworkScore.KEEP_CONNECTED_NONE
+import android.net.NetworkScore.POLICY_EXITING
+import android.net.NetworkScore.POLICY_TRANSPORT_PRIMARY
+import android.net.NetworkScore.POLICY_YIELD_TO_BAD_WIFI
+import android.os.Build
 import androidx.test.filters.SmallTest
-import androidx.test.runner.AndroidJUnit4
+import com.android.server.connectivity.FullScore.POLICY_EVER_VALIDATED_NOT_AVOIDED_WHEN_BAD
+import com.android.server.connectivity.FullScore.POLICY_IS_VALIDATED
+import com.android.testutils.DevSdkIgnoreRule
+import com.android.testutils.DevSdkIgnoreRunner
 import org.junit.Test
 import org.junit.runner.RunWith
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.doReturn
-import org.mockito.Mockito.mock
 import kotlin.test.assertEquals
-import kotlin.test.assertNull
 
-@RunWith(AndroidJUnit4::class)
+private fun score(vararg policies: Int) = FullScore(0 /* legacyScore */,
+        policies.fold(0L) { acc, e -> acc or (1L shl e) }, KEEP_CONNECTED_NONE) // policy bitmask
+private fun caps(transport: Int) = NetworkCapabilities.Builder().addTransportType(transport).build()
+
 @SmallTest
+@RunWith(DevSdkIgnoreRunner::class)
+@DevSdkIgnoreRule.IgnoreUpTo(Build.VERSION_CODES.R)
 class NetworkRankerTest {
-    private val ranker = NetworkRanker()
+    private val mRanker = NetworkRanker()
 
-    private fun makeNai(satisfy: Boolean, legacyScore: Int) =
-            mock(NetworkAgentInfo::class.java).also {
-                doReturn(satisfy).`when`(it).satisfies(any())
-                val fs = FullScore(legacyScore, 0 /* policies */, KEEP_CONNECTED_NONE)
-                doReturn(fs).`when`(it).getScore()
-                val nc = NetworkCapabilities.Builder().build()
-                doReturn(nc).`when`(it).getCapsNoCopy()
-            }
-
-    @Test
-    fun testGetBestNetwork() {
-        val scores = listOf(20, 50, 90, 60, 23, 68)
-        val nais = scores.map { makeNai(true, it) }
-        val bestNetwork = nais[2] // The one with the top score
-        val someRequest = mock(NetworkRequest::class.java)
-        assertEquals(bestNetwork, ranker.getBestNetwork(someRequest, nais, bestNetwork))
+    private class TestScore(private val sc: FullScore, private val nc: NetworkCapabilities)
+            : NetworkRanker.Scoreable { // Minimal Scoreable stub
+        override fun getScore() = sc
+        override fun getCapsNoCopy(): NetworkCapabilities = nc
     }
 
     @Test
-    fun testIgnoreNonSatisfying() {
-        val nais = listOf(makeNai(true, 20), makeNai(true, 50), makeNai(false, 90),
-                makeNai(false, 60), makeNai(true, 23), makeNai(false, 68))
-        val bestNetwork = nais[1] // Top score that's satisfying
-        val someRequest = mock(NetworkRequest::class.java)
-        assertEquals(bestNetwork, ranker.getBestNetwork(someRequest, nais, nais[1]))
+    fun testYieldToBadWiFiOneCell() {
+        // Only cell, it wins
+        val winner = TestScore(score(POLICY_YIELD_TO_BAD_WIFI, POLICY_IS_VALIDATED),
+                caps(TRANSPORT_CELLULAR))
+        val scores = listOf(winner)
+        assertEquals(winner, mRanker.getBestNetworkByPolicy(scores, null))
     }
 
     @Test
-    fun testNoMatch() {
-        val nais = listOf(makeNai(false, 20), makeNai(false, 50), makeNai(false, 90))
-        val someRequest = mock(NetworkRequest::class.java)
-        assertNull(ranker.getBestNetwork(someRequest, nais, null))
+    fun testYieldToBadWiFiOneCellOneBadWiFi() {
+        // Bad wifi wins against yielding validated cell
+        val winner = TestScore(score(POLICY_EVER_VALIDATED_NOT_AVOIDED_WHEN_BAD),
+                caps(TRANSPORT_WIFI))
+        val scores = listOf(
+                winner,
+                TestScore(score(POLICY_YIELD_TO_BAD_WIFI, POLICY_IS_VALIDATED),
+                        caps(TRANSPORT_CELLULAR))
+        )
+        assertEquals(winner, mRanker.getBestNetworkByPolicy(scores, null))
     }
 
     @Test
-    fun testEmpty() {
-        val someRequest = mock(NetworkRequest::class.java)
-        assertNull(ranker.getBestNetwork(someRequest, emptyList(), null))
+    fun testYieldToBadWiFiOneCellTwoBadWiFi() {
+        // Bad wifi wins against yielding validated cell. Prefer the one that's primary.
+        val winner = TestScore(score(POLICY_EVER_VALIDATED_NOT_AVOIDED_WHEN_BAD,
+                POLICY_TRANSPORT_PRIMARY), caps(TRANSPORT_WIFI))
+        val scores = listOf(
+                winner,
+                TestScore(score(POLICY_EVER_VALIDATED_NOT_AVOIDED_WHEN_BAD),
+                        caps(TRANSPORT_WIFI)),
+                TestScore(score(POLICY_YIELD_TO_BAD_WIFI, POLICY_IS_VALIDATED),
+                        caps(TRANSPORT_CELLULAR))
+        )
+        assertEquals(winner, mRanker.getBestNetworkByPolicy(scores, null))
     }
 
-    // Make sure the ranker is "stable" (as in stable sort), that is, it always returns the FIRST
-    // network satisfying the request if multiple of them have the same score.
     @Test
-    fun testStable() {
-        val nais1 = listOf(makeNai(true, 30), makeNai(true, 30), makeNai(true, 30),
-                makeNai(true, 30), makeNai(true, 30), makeNai(true, 30))
-        val someRequest = mock(NetworkRequest::class.java)
-        assertEquals(nais1[0], ranker.getBestNetwork(someRequest, nais1, nais1[0]))
+    fun testYieldToBadWiFiOneCellTwoBadWiFiOneNotAvoided() {
+        // Bad wifi that was ever validated wins against bad wifi that was never validated (or
+        // was avoided when bad).
+        val winner = TestScore(score(POLICY_EVER_VALIDATED_NOT_AVOIDED_WHEN_BAD),
+                caps(TRANSPORT_WIFI))
+        val scores = listOf(
+                winner,
+                TestScore(score(), caps(TRANSPORT_WIFI)),
+                TestScore(score(POLICY_YIELD_TO_BAD_WIFI, POLICY_IS_VALIDATED),
+                        caps(TRANSPORT_CELLULAR))
+        )
+        assertEquals(winner, mRanker.getBestNetworkByPolicy(scores, null))
+    }
 
-        val nais2 = listOf(makeNai(true, 30), makeNai(true, 50), makeNai(true, 20),
-                makeNai(true, 50), makeNai(true, 50), makeNai(true, 40))
-        assertEquals(nais2[1], ranker.getBestNetwork(someRequest, nais2, nais2[1]))
+    @Test
+    fun testYieldToBadWiFiOneCellOneBadWiFiOneGoodWiFi() {
+        // Good wifi wins over bad wifi and yielding cell
+        val winner = TestScore(score(POLICY_EVER_VALIDATED_NOT_AVOIDED_WHEN_BAD,
+                POLICY_IS_VALIDATED), caps(TRANSPORT_WIFI))
+        val scores = listOf(
+                winner,
+                TestScore(score(POLICY_EVER_VALIDATED_NOT_AVOIDED_WHEN_BAD,
+                        POLICY_TRANSPORT_PRIMARY), caps(TRANSPORT_WIFI)),
+                TestScore(score(POLICY_YIELD_TO_BAD_WIFI, POLICY_IS_VALIDATED),
+                        caps(TRANSPORT_CELLULAR))
+        )
+        assertEquals(winner, mRanker.getBestNetworkByPolicy(scores, null))
+    }
+
+    @Test
+    fun testYieldToBadWiFiTwoCellsOneBadWiFi() {
+        // Cell that doesn't yield wins over both the cell that yields and the bad wifi
+        val winner = TestScore(score(POLICY_IS_VALIDATED), caps(TRANSPORT_CELLULAR))
+        val scores = listOf(
+                winner,
+                TestScore(score(POLICY_EVER_VALIDATED_NOT_AVOIDED_WHEN_BAD,
+                        POLICY_TRANSPORT_PRIMARY), caps(TRANSPORT_WIFI)),
+                TestScore(score(POLICY_YIELD_TO_BAD_WIFI, POLICY_IS_VALIDATED),
+                        caps(TRANSPORT_CELLULAR))
+        )
+        assertEquals(winner, mRanker.getBestNetworkByPolicy(scores, null))
+    }
+
+    @Test
+    fun testYieldToBadWiFiTwoCellsOneBadWiFiOneGoodWiFi() {
+        // Good wifi wins over bad wifi, cell that doesn't yield and cell that yields
+        val winner = TestScore(score(POLICY_IS_VALIDATED), caps(TRANSPORT_WIFI))
+        val scores = listOf(
+                winner,
+                TestScore(score(POLICY_EVER_VALIDATED_NOT_AVOIDED_WHEN_BAD,
+                        POLICY_TRANSPORT_PRIMARY), caps(TRANSPORT_WIFI)),
+                TestScore(score(POLICY_IS_VALIDATED), caps(TRANSPORT_CELLULAR)),
+                TestScore(score(POLICY_YIELD_TO_BAD_WIFI, POLICY_IS_VALIDATED),
+                        caps(TRANSPORT_CELLULAR))
+        )
+        assertEquals(winner, mRanker.getBestNetworkByPolicy(scores, null))
+    }
+
+    @Test
+    fun testYieldToBadWiFiOneExitingGoodWiFi() {
+        // Yielding cell wins over good exiting wifi
+        val winner = TestScore(score(POLICY_YIELD_TO_BAD_WIFI, POLICY_IS_VALIDATED),
+                caps(TRANSPORT_CELLULAR))
+        val scores = listOf(
+                winner,
+                TestScore(score(POLICY_IS_VALIDATED, POLICY_EXITING), caps(TRANSPORT_WIFI))
+        )
+        assertEquals(winner, mRanker.getBestNetworkByPolicy(scores, null))
+    }
+
+    @Test
+    fun testYieldToBadWiFiOneExitingBadWiFi() {
+        // Yielding cell wins over bad exiting wifi
+        val winner = TestScore(score(POLICY_YIELD_TO_BAD_WIFI, POLICY_IS_VALIDATED),
+                caps(TRANSPORT_CELLULAR))
+        val scores = listOf(
+                winner,
+                TestScore(score(POLICY_EVER_VALIDATED_NOT_AVOIDED_WHEN_BAD,
+                        POLICY_EXITING), caps(TRANSPORT_WIFI))
+        )
+        assertEquals(winner, mRanker.getBestNetworkByPolicy(scores, null))
     }
 }