Merge "Select Thread channel from preferred or supported channel mask" into main
diff --git a/thread/framework/java/android/net/thread/ThreadNetworkManager.java b/thread/framework/java/android/net/thread/ThreadNetworkManager.java
index b584487..150b759 100644
--- a/thread/framework/java/android/net/thread/ThreadNetworkManager.java
+++ b/thread/framework/java/android/net/thread/ThreadNetworkManager.java
@@ -83,8 +83,8 @@
      * This user restriction specifies if Thread network is disallowed on the device. If Thread
      * network is disallowed it cannot be turned on via Settings.
      *
-     * <p>this is a mirror of {@link UserManager#DISALLOW_THREAD_NETWORK} which is not available
-     * on Android U devices.
+     * <p>This is a mirror of {@link UserManager#DISALLOW_THREAD_NETWORK} which is not available on
+     * Android U devices.
      *
      * @hide
      */
diff --git a/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java b/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java
new file mode 100644
index 0000000..e3b4e1a
--- /dev/null
+++ b/thread/service/java/com/android/server/thread/ActiveOperationalDatasetReceiverWrapper.java
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2024 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.android.server.thread;
+
+import static android.net.thread.ThreadNetworkException.ERROR_UNAVAILABLE;
+
+import android.net.thread.ActiveOperationalDataset;
+import android.net.thread.IActiveOperationalDatasetReceiver;
+import android.os.RemoteException;
+
+import com.android.internal.annotations.GuardedBy;
+
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * A {@link IActiveOperationalDatasetReceiver} wrapper which makes it easier to invoke the
+ * callbacks: a dead client is ignored and receivers still awaiting a result are tracked.
+ */
+final class ActiveOperationalDatasetReceiverWrapper {
+    private final IActiveOperationalDatasetReceiver mReceiver;
+
+    private static final Object sPendingReceiversLock = new Object();
+
+    @GuardedBy("sPendingReceiversLock")
+    private static final Set<ActiveOperationalDatasetReceiverWrapper> sPendingReceivers =
+            new HashSet<>();
+
+    /** Wraps {@code receiver} and registers the new wrapper as pending. */
+    public ActiveOperationalDatasetReceiverWrapper(IActiveOperationalDatasetReceiver receiver) {
+        this.mReceiver = receiver;
+        synchronized (sPendingReceiversLock) {
+            sPendingReceivers.add(this);
+        }
+    }
+
+    /** Reports {@code ERROR_UNAVAILABLE} to every pending receiver and forgets all of them. */
+    public static void onOtDaemonDied() {
+        final ActiveOperationalDatasetReceiverWrapper[] pending;
+        synchronized (sPendingReceiversLock) {
+            pending = sPendingReceivers.toArray(new ActiveOperationalDatasetReceiverWrapper[0]);
+            sPendingReceivers.clear();
+        }
+        // Deliver the errors only after the lock is released, never from inside it.
+        for (ActiveOperationalDatasetReceiverWrapper wrapper : pending) {
+            wrapper.onError(ERROR_UNAVAILABLE, "Thread daemon died");
+        }
+    }
+
+    /** Stops tracking this wrapper as pending, then forwards {@code dataset} to the receiver. */
+    public void onSuccess(ActiveOperationalDataset dataset) {
+        synchronized (sPendingReceiversLock) {
+            sPendingReceivers.remove(this);
+        }
+        try {
+            mReceiver.onSuccess(dataset);
+        } catch (RemoteException e) {
+            // The client is dead, do nothing
+        }
+    }
+
+    /** Stops tracking this wrapper as pending, then forwards the error to the receiver. */
+    public void onError(int errorCode, String errorMessage) {
+        synchronized (sPendingReceiversLock) {
+            sPendingReceivers.remove(this);
+        }
+        try {
+            mReceiver.onError(errorCode, errorMessage);
+        } catch (RemoteException e) {
+            // The client is dead, do nothing
+        }
+    }
+}
diff --git a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
index 01d1179..56dd056 100644
--- a/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
+++ b/thread/service/java/com/android/server/thread/ThreadNetworkControllerService.java
@@ -109,6 +109,7 @@
 import com.android.internal.annotations.VisibleForTesting;
 import com.android.server.ServiceManagerWrapper;
 import com.android.server.thread.openthread.BorderRouterConfigurationParcel;
+import com.android.server.thread.openthread.IChannelMasksReceiver;
 import com.android.server.thread.openthread.IOtDaemon;
 import com.android.server.thread.openthread.IOtDaemonCallback;
 import com.android.server.thread.openthread.IOtStatusReceiver;
@@ -157,9 +158,6 @@
     private final NsdPublisher mNsdPublisher;
     private final OtDaemonCallbackProxy mOtDaemonCallbackProxy = new OtDaemonCallbackProxy();
 
-    // TODO(b/308310823): read supported channel from Thread dameon
-    private final int mSupportedChannelMask = 0x07FFF800; // from channel 11 to 26
-
     @Nullable private IOtDaemon mOtDaemon;
     @Nullable private NetworkAgent mNetworkAgent;
     @Nullable private NetworkAgent mTestNetworkAgent;
@@ -593,26 +591,51 @@
     @Override
     public void createRandomizedDataset(
             String networkName, IActiveOperationalDatasetReceiver receiver) {
-        mHandler.post(
-                () -> {
-                    ActiveOperationalDataset dataset =
-                            createRandomizedDatasetInternal(
-                                    networkName,
-                                    mSupportedChannelMask,
-                                    Instant.now(),
-                                    new Random(),
-                                    new SecureRandom());
-                    try {
-                        receiver.onSuccess(dataset);
-                    } catch (RemoteException e) {
-                        // The client is dead, do nothing
-                    }
-                });
+        ActiveOperationalDatasetReceiverWrapper receiverWrapper =
+                new ActiveOperationalDatasetReceiverWrapper(receiver);
+        mHandler.post(() -> createRandomizedDatasetInternal(networkName, receiverWrapper));
     }
 
-    private static ActiveOperationalDataset createRandomizedDatasetInternal(
+    private void createRandomizedDatasetInternal(
+            String networkName, @NonNull ActiveOperationalDatasetReceiverWrapper receiver) {
+        checkOnHandlerThread();
+        // Only the channel masks are requested here; the reply callback produces the dataset.
+        try {
+            getOtDaemon().getChannelMasks(newChannelMasksReceiver(networkName, receiver));
+        } catch (RemoteException e) {
+            Log.e(TAG, "otDaemon.getChannelMasks failed", e);
+            receiver.onError(ERROR_INTERNAL_ERROR, "Thread stack error");
+        }
+    }
+
+    private IChannelMasksReceiver newChannelMasksReceiver(
+            String networkName, ActiveOperationalDatasetReceiverWrapper receiver) {
+        // The returned stub turns the reported channel masks into a randomized dataset.
+        return new IChannelMasksReceiver.Stub() {
+            @Override
+            public void onSuccess(int supportedChannelMask, int preferredChannelMask) {
+                receiver.onSuccess(
+                        createRandomizedDataset(
+                                networkName,
+                                supportedChannelMask,
+                                preferredChannelMask,
+                                Instant.now(),
+                                new Random(),
+                                new SecureRandom()));
+            }
+
+            @Override
+            public void onError(int errorCode, String errorMessage) {
+                final int androidErrorCode = otErrorToAndroidError(errorCode);
+                receiver.onError(androidErrorCode, errorMessage);
+            }
+        };
+    }
+
+    private static ActiveOperationalDataset createRandomizedDataset(
             String networkName,
             int supportedChannelMask,
+            int preferredChannelMask,
             Instant now,
             Random random,
             SecureRandom secureRandom) {
@@ -622,6 +645,7 @@
 
         final SparseArray<byte[]> channelMask = new SparseArray<>(1);
         channelMask.put(CHANNEL_PAGE_24_GHZ, channelMaskToByteArray(supportedChannelMask));
+        final int channel = selectChannel(supportedChannelMask, preferredChannelMask, random);
 
         final byte[] securityFlags = new byte[] {(byte) 0xff, (byte) 0xf8};
 
@@ -632,7 +656,7 @@
                 .setExtendedPanId(newRandomBytes(random, LENGTH_EXTENDED_PAN_ID))
                 .setPanId(panId)
                 .setNetworkName(networkName)
-                .setChannel(CHANNEL_PAGE_24_GHZ, selectRandomChannel(supportedChannelMask, random))
+                .setChannel(CHANNEL_PAGE_24_GHZ, channel)
                 .setChannelMask(channelMask)
                 .setPskc(newRandomBytes(secureRandom, LENGTH_PSKC))
                 .setNetworkKey(newRandomBytes(secureRandom, LENGTH_NETWORK_KEY))
@@ -641,6 +665,18 @@
                 .build();
     }
 
+    /** Picks a random channel, favoring preferred channels that are also supported. */
+    private static int selectChannel(
+            int supportedChannelMask, int preferredChannelMask, Random random) {
+        // Only preferred channels which are also supported are eligible; when there is none,
+        // fall back to the whole supported channel mask.
+        final int usablePreferredMask = preferredChannelMask & supportedChannelMask;
+        final int candidateMask =
+                (usablePreferredMask != 0) ? usablePreferredMask : supportedChannelMask;
+
+        return selectRandomChannel(candidateMask, random);
+    }
+
     private static byte[] newRandomBytes(Random random, int length) {
         byte[] result = new byte[length];
         random.nextBytes(result);
diff --git a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
index 1640679..b557e65 100644
--- a/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
+++ b/thread/tests/unit/src/com/android/server/thread/ThreadNetworkControllerServiceTest.java
@@ -16,6 +16,7 @@
 
 package com.android.server.thread;
 
+import static android.net.thread.ActiveOperationalDataset.CHANNEL_PAGE_24_GHZ;
 import static android.net.thread.ThreadNetworkController.STATE_DISABLED;
 import static android.net.thread.ThreadNetworkController.STATE_ENABLED;
 import static android.net.thread.ThreadNetworkException.ERROR_FAILED_PRECONDITION;
@@ -23,6 +24,7 @@
 import static android.net.thread.ThreadNetworkManager.DISALLOW_THREAD_NETWORK;
 import static android.net.thread.ThreadNetworkManager.PERMISSION_THREAD_NETWORK_PRIVILEGED;
 
+import static com.android.server.thread.openthread.IOtDaemon.ErrorCode.OT_ERROR_INVALID_STATE;
 import static com.android.testutils.TestPermissionUtil.runAsShell;
 
 import static com.google.common.io.BaseEncoding.base16;
@@ -30,8 +32,10 @@
 
 import static org.junit.Assert.assertThrows;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -47,9 +51,11 @@
 import android.net.NetworkAgent;
 import android.net.NetworkProvider;
 import android.net.thread.ActiveOperationalDataset;
+import android.net.thread.IActiveOperationalDatasetReceiver;
 import android.net.thread.IOperationReceiver;
 import android.net.thread.ThreadNetworkException;
 import android.os.Handler;
+import android.os.IBinder;
 import android.os.ParcelFileDescriptor;
 import android.os.RemoteException;
 import android.os.UserManager;
@@ -64,6 +70,8 @@
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 
@@ -95,6 +103,12 @@
                                     + "B9D351B40C0402A0FFF8");
     private static final ActiveOperationalDataset DEFAULT_ACTIVE_DATASET =
             ActiveOperationalDataset.fromThreadTlvs(DEFAULT_ACTIVE_DATASET_TLVS);
+    private static final String DEFAULT_NETWORK_NAME = "thread-wpan0";
+    private static final int OT_ERROR_NONE = 0;
+    private static final int DEFAULT_SUPPORTED_CHANNEL_MASK = 0x07FFF800; // from channel 11 to 26
+    private static final int DEFAULT_PREFERRED_CHANNEL_MASK = 0x00000800; // channel 11
+    private static final int DEFAULT_SELECTED_CHANNEL = 11;
+    private static final byte[] DEFAULT_SUPPORTED_CHANNEL_MASK_ARRAY = base16().decode("001FFFE0");
 
     @Mock private ConnectivityManager mMockConnectivityManager;
     @Mock private NetworkAgent mMockNetworkAgent;
@@ -104,10 +118,12 @@
     @Mock private ThreadPersistentSettings mMockPersistentSettings;
     @Mock private NsdPublisher mMockNsdPublisher;
     @Mock private UserManager mMockUserManager;
+    @Mock private IBinder mIBinder;
     private Context mContext;
     private TestLooper mTestLooper;
     private FakeOtDaemon mFakeOtDaemon;
     private ThreadNetworkControllerService mService;
+    @Captor private ArgumentCaptor<ActiveOperationalDataset> mActiveDatasetCaptor;
 
     @Before
     public void setUp() {
@@ -281,4 +297,40 @@
             }
         };
     }
+
+    @Test
+    public void createRandomizedDataset_succeed_activeDatasetCreated() throws Exception {
+        final IActiveOperationalDatasetReceiver mockReceiver =
+                mock(IActiveOperationalDatasetReceiver.class);
+        mFakeOtDaemon.setChannelMasks(
+                DEFAULT_SUPPORTED_CHANNEL_MASK, DEFAULT_PREFERRED_CHANNEL_MASK);
+        mFakeOtDaemon.setChannelMasksReceiverOtError(OT_ERROR_NONE);
+        // The preferred mask holds a single supported channel, so one channel is expected below.
+        mService.createRandomizedDataset(DEFAULT_NETWORK_NAME, mockReceiver);
+        mTestLooper.dispatchAll();
+        // Exactly one onSuccess() and no onError() must reach the receiver.
+        verify(mockReceiver, never()).onError(anyInt(), anyString());
+        verify(mockReceiver, times(1)).onSuccess(mActiveDatasetCaptor.capture());
+        ActiveOperationalDataset activeDataset = mActiveDatasetCaptor.getValue();
+        assertThat(activeDataset.getNetworkName()).isEqualTo(DEFAULT_NETWORK_NAME);
+        assertThat(activeDataset.getChannelMask().size()).isEqualTo(1);
+        assertThat(activeDataset.getChannelMask().get(CHANNEL_PAGE_24_GHZ))
+                .isEqualTo(DEFAULT_SUPPORTED_CHANNEL_MASK_ARRAY);
+        assertThat(activeDataset.getChannel()).isEqualTo(DEFAULT_SELECTED_CHANNEL);
+    }
+
+    @Test
+    public void createRandomizedDataset_otDaemonRemoteFailure_returnsPreconditionError()
+            throws Exception {
+        final IActiveOperationalDatasetReceiver mockReceiver =
+                mock(IActiveOperationalDatasetReceiver.class);
+        mFakeOtDaemon.setChannelMasksReceiverOtError(OT_ERROR_INVALID_STATE);
+        when(mockReceiver.asBinder()).thenReturn(mIBinder);
+        // With the error set above, the fake daemon is expected to fail the channel masks query.
+        mService.createRandomizedDataset(DEFAULT_NETWORK_NAME, mockReceiver);
+        mTestLooper.dispatchAll();
+        // NOTE(review): the name says "PreconditionError" but ERROR_INTERNAL_ERROR is asserted.
+        verify(mockReceiver, never()).onSuccess(any(ActiveOperationalDataset.class));
+        verify(mockReceiver, times(1)).onError(eq(ERROR_INTERNAL_ERROR), anyString());
+    }
 }