Merge "Add sthal_cli"
diff --git a/audio/7.0/IDevice.hal b/audio/7.0/IDevice.hal
index e423f29..85c789a 100644
--- a/audio/7.0/IDevice.hal
+++ b/audio/7.0/IDevice.hal
@@ -245,6 +245,7 @@
/**
* Gets the HW synchronization source of the device. Calling this method is
* equivalent to getting AUDIO_PARAMETER_HW_AV_SYNC on the legacy HAL.
+ *
* Optional method
*
* @return retval operation completion status: OK or NOT_SUPPORTED.
@@ -255,6 +256,7 @@
/**
* Sets whether the screen is on. Calling this method is equivalent to
* setting AUDIO_PARAMETER_KEY_SCREEN_STATE on the legacy HAL.
+ *
* Optional method
*
* @param turnedOn whether the screen is turned on.
diff --git a/audio/7.0/IStream.hal b/audio/7.0/IStream.hal
index 393e38f..e4987c2 100644
--- a/audio/7.0/IStream.hal
+++ b/audio/7.0/IStream.hal
@@ -110,6 +110,7 @@
/**
* Return the set of devices which this stream is connected to.
+ *
* Optional method
*
* @return retval operation completion status: OK or NOT_SUPPORTED.
@@ -133,6 +134,7 @@
/**
* Sets the HW synchronization source. Calling this method is equivalent to
* setting AUDIO_PARAMETER_STREAM_HW_AV_SYNC on the legacy HAL.
+ *
* Optional method
*
* @param hwAvSync HW synchronization source
diff --git a/audio/7.0/IStreamIn.hal b/audio/7.0/IStreamIn.hal
index bf9ae52..be4bda4 100644
--- a/audio/7.0/IStreamIn.hal
+++ b/audio/7.0/IStreamIn.hal
@@ -24,6 +24,7 @@
* Returns the source descriptor of the input stream. Calling this method is
* equivalent to getting AUDIO_PARAMETER_STREAM_INPUT_SOURCE on the legacy
* HAL.
+ *
* Optional method
*
* @return retval operation completion status.
@@ -33,6 +34,7 @@
/**
* Set the input gain for the audio driver.
+ *
* Optional method
*
* @param gain 1.0f is unity, 0.0f is zero.
@@ -42,6 +44,7 @@
/**
* Called when the metadata of the stream's sink has been changed.
+ *
* Optional method
*
* @param sinkMetadata Description of the audio that is suggested by the clients.
@@ -148,7 +151,8 @@
/**
* Return a recent count of the number of audio frames received and the
- * clock time associated with that frame count.
+ * clock time associated with that frame count. The count must not reset to
+ * zero when a PCM input enters standby.
*
* @return retval INVALID_STATE if the device is not ready/available,
* NOT_SUPPORTED if the command is not supported,
diff --git a/audio/7.0/IStreamOut.hal b/audio/7.0/IStreamOut.hal
index 78cb51b..6e8498e 100644
--- a/audio/7.0/IStreamOut.hal
+++ b/audio/7.0/IStreamOut.hal
@@ -35,6 +35,7 @@
* allowing to directly set the volume as apposed to via the framework.
* This method might produce multiple PCM outputs or hardware accelerated
* codecs, such as MP3 or AAC.
+ *
* Optional method
*
* @param left left channel attenuation, 1.0f is unity, 0.0f is zero.
@@ -46,6 +47,7 @@
/**
* Called when the metadata of the stream's source has been changed.
+ *
* Optional method
*
* @param sourceMetadata Description of the audio that is played by the clients.
@@ -130,6 +132,7 @@
/**
* Return the number of audio frames written by the audio DSP to DAC since
* the output has exited standby.
+ *
* Optional method
*
* @return retval operation completion status.
@@ -141,6 +144,7 @@
* Get the local time at which the next write to the audio driver will be
* presented. The units are microseconds, where the epoch is decided by the
* local audio HAL.
+ *
* Optional method
*
* @return retval operation completion status.
@@ -253,8 +257,11 @@
drain(AudioDrain type) generates (Result retval);
/**
- * Notifies to the audio driver to flush the queued data. Stream must
- * already be paused before calling 'flush'.
+ * Notifies the audio driver to flush (that is, drop) the queued
+ * data. Stream must already be paused before calling 'flush'. For
+ * compressed and offload streams the frame count returned by
+ * 'getPresentationPosition' must reset after flush.
+ *
* Optional method
*
* Implementation of this function is mandatory for offloaded playback.
@@ -266,12 +273,14 @@
/**
* Return a recent count of the number of audio frames presented to an
* external observer. This excludes frames which have been written but are
- * still in the pipeline. The count is not reset to zero when output enters
- * standby. Also returns the value of CLOCK_MONOTONIC as of this
+ * still in the pipeline. The count must not reset to zero when a PCM output
+ * enters standby. For compressed and offload streams it is recommended that
+ * the HAL reset the frame count.
+ *
+ * This method also returns the value of CLOCK_MONOTONIC as of this
* presentation count. The returned count is expected to be 'recent', but
* does not need to be the most recent possible value. However, the
* associated time must correspond to whatever count is returned.
- *
* Example: assume that N+M frames have been presented, where M is a 'small'
* number. Then it is permissible to return N instead of N+M, and the
* timestamp must correspond to N rather than N+M. The terms 'recent' and
@@ -287,6 +296,7 @@
/**
* Selects a presentation for decoding from a next generation media stream
* (as defined per ETSI TS 103 190-2) and a program within the presentation.
+ *
* Optional method
*
* @param presentationId selected audio presentation.
diff --git a/audio/common/7.0/types.hal b/audio/common/7.0/types.hal
index bea0705..4f920e4 100644
--- a/audio/common/7.0/types.hal
+++ b/audio/common/7.0/types.hal
@@ -344,7 +344,7 @@
DeviceAddress device;
} destination;
AudioChannelMask channelMask;
- /** Tags from AudioTrack audio atttributes */
+ /** Tags from AudioRecord audio attributes */
vec<AudioTag> tags;
};
diff --git a/automotive/vehicle/2.0/default/impl/vhal_v2_0/GeneratorHub.cpp b/automotive/vehicle/2.0/default/impl/vhal_v2_0/GeneratorHub.cpp
index 548285a..9be9ea7 100644
--- a/automotive/vehicle/2.0/default/impl/vhal_v2_0/GeneratorHub.cpp
+++ b/automotive/vehicle/2.0/default/impl/vhal_v2_0/GeneratorHub.cpp
@@ -31,6 +31,21 @@
GeneratorHub::GeneratorHub(const OnHalEvent& onHalEvent)
: mOnHalEvent(onHalEvent), mThread(&GeneratorHub::run, this) {}
+// Stops the generator thread: raises the shutdown flag, wakes run() and joins the thread.
+GeneratorHub::~GeneratorHub() {
+    {
+        // Raise the flag while holding mLock. run() evaluates its wait predicate under mLock,
+        // so a store made without the lock could land between that check and the thread
+        // blocking, losing the notification below and hanging join() forever.
+        std::lock_guard<std::mutex> g(mLock);
+        mShuttingDownFlag.store(true);
+    }
+    mCond.notify_all();
+    if (mThread.joinable()) {
+        mThread.join();
+    }
+}
+
void GeneratorHub::registerGenerator(int32_t cookie, FakeValueGeneratorPtr generator) {
{
std::lock_guard<std::mutex> g(mLock);
@@ -58,15 +66,18 @@
}
void GeneratorHub::run() {
- while (true) {
+ while (!mShuttingDownFlag.load()) {
std::unique_lock<std::mutex> g(mLock);
// Pop events whose generator does not exist (may be already unregistered)
while (!mEventQueue.empty()
&& mGenerators.find(mEventQueue.top().cookie) == mGenerators.end()) {
mEventQueue.pop();
}
- // Wait until event queue is not empty
- mCond.wait(g, [this] { return !mEventQueue.empty(); });
+ // Wait until event queue is not empty or shutting down flag is set
+ mCond.wait(g, [this] { return !mEventQueue.empty() || mShuttingDownFlag.load(); });
+ if (mShuttingDownFlag.load()) {
+ break;
+ }
const VhalEvent& curEvent = mEventQueue.top();
diff --git a/automotive/vehicle/2.0/default/impl/vhal_v2_0/GeneratorHub.h b/automotive/vehicle/2.0/default/impl/vhal_v2_0/GeneratorHub.h
index dcf6a4f..b25dbf1 100644
--- a/automotive/vehicle/2.0/default/impl/vhal_v2_0/GeneratorHub.h
+++ b/automotive/vehicle/2.0/default/impl/vhal_v2_0/GeneratorHub.h
@@ -58,7 +58,7 @@
public:
GeneratorHub(const OnHalEvent& onHalEvent);
- ~GeneratorHub() = default;
+ ~GeneratorHub();
/**
* Register a new generator. The generator will be discarded if it could not produce next event.
@@ -84,6 +84,9 @@
mutable std::mutex mLock;
std::condition_variable mCond;
+    // Must be declared before mThread: members are initialized in declaration order, and the
+    // constructor starts mThread running run(), which reads this flag immediately.
+    std::atomic<bool> mShuttingDownFlag{false};
std::thread mThread;
};
} // namespace impl
diff --git a/biometrics/face/aidl/aidl_api/android.hardware.biometrics.face/current/android/hardware/biometrics/face/AcquiredInfo.aidl b/biometrics/face/aidl/aidl_api/android.hardware.biometrics.face/current/android/hardware/biometrics/face/AcquiredInfo.aidl
index 2600e61..c19534c 100644
--- a/biometrics/face/aidl/aidl_api/android.hardware.biometrics.face/current/android/hardware/biometrics/face/AcquiredInfo.aidl
+++ b/biometrics/face/aidl/aidl_api/android.hardware.biometrics.face/current/android/hardware/biometrics/face/AcquiredInfo.aidl
@@ -59,7 +59,5 @@
VENDOR = 22,
FIRST_FRAME_RECEIVED = 23,
DARK_GLASSES_DETECTED = 24,
- FACE_COVERING_DETECTED = 25,
- EYES_NOT_VISIBLE = 26,
- MOUTH_NOT_VISIBLE = 27,
+ MOUTH_COVERING_DETECTED = 25,
}
diff --git a/biometrics/face/aidl/aidl_api/android.hardware.biometrics.face/current/android/hardware/biometrics/face/IFace.aidl b/biometrics/face/aidl/aidl_api/android.hardware.biometrics.face/current/android/hardware/biometrics/face/IFace.aidl
index 0d1ef45..fc4a4d0 100644
--- a/biometrics/face/aidl/aidl_api/android.hardware.biometrics.face/current/android/hardware/biometrics/face/IFace.aidl
+++ b/biometrics/face/aidl/aidl_api/android.hardware.biometrics.face/current/android/hardware/biometrics/face/IFace.aidl
@@ -36,5 +36,4 @@
interface IFace {
android.hardware.biometrics.face.SensorProps[] getSensorProps();
android.hardware.biometrics.face.ISession createSession(in int sensorId, in int userId, in android.hardware.biometrics.face.ISessionCallback cb);
- void reset();
}
diff --git a/biometrics/face/aidl/android/hardware/biometrics/face/AcquiredInfo.aidl b/biometrics/face/aidl/android/hardware/biometrics/face/AcquiredInfo.aidl
index 217a9bb..a3b229e 100644
--- a/biometrics/face/aidl/android/hardware/biometrics/face/AcquiredInfo.aidl
+++ b/biometrics/face/aidl/android/hardware/biometrics/face/AcquiredInfo.aidl
@@ -187,7 +187,7 @@
*/
ROLL_TOO_EXTREME = 18,
- /**
+ /**
* The user’s face has been obscured by some object.
*
* The user should be informed to remove any objects from the line of sight from
@@ -230,18 +230,5 @@
* A face mask or face covering detected. This can be useful for providing relevant feedback to
* the user and enabling an alternative authentication logic if the implementation supports it.
*/
- FACE_COVERING_DETECTED = 25,
-
- /**
- * Either one or both eyes are not visible in the frame. Prefer to use DARK_GLASSES_DETECTED if
- * the eyes are not visible due to dark glasses.
- */
- EYES_NOT_VISIBLE = 26,
-
- /**
- * The mouth is not visible in the frame. Prefer to use MASK_DETECTED if the mouth is not
- * visible due to a mask.
- */
- MOUTH_NOT_VISIBLE = 27,
+ MOUTH_COVERING_DETECTED = 25,
}
-
diff --git a/biometrics/face/aidl/android/hardware/biometrics/face/IFace.aidl b/biometrics/face/aidl/android/hardware/biometrics/face/IFace.aidl
index afb7c8d..11cdf77 100644
--- a/biometrics/face/aidl/android/hardware/biometrics/face/IFace.aidl
+++ b/biometrics/face/aidl/android/hardware/biometrics/face/IFace.aidl
@@ -50,14 +50,4 @@
* @return A new session.
*/
ISession createSession(in int sensorId, in int userId, in ISessionCallback cb);
-
- /**
- * Resets the HAL into a clean state, forcing it to cancel all of the pending operations, close
- * its current session, and release all of the acquired resources.
- *
- * This should be used as a last resort to recover the HAL if the current session becomes
- * unresponsive. The implementation might choose to restart the HAL process to get back into a
- * good state.
- */
- void reset();
}
diff --git a/biometrics/face/aidl/default/Face.cpp b/biometrics/face/aidl/default/Face.cpp
index 73e50f3..a4520de 100644
--- a/biometrics/face/aidl/default/Face.cpp
+++ b/biometrics/face/aidl/default/Face.cpp
@@ -63,8 +63,4 @@
return ndk::ScopedAStatus::ok();
}
-ndk::ScopedAStatus Face::reset() {
- return ndk::ScopedAStatus::ok();
-}
-
} // namespace aidl::android::hardware::biometrics::face
diff --git a/biometrics/face/aidl/default/Face.h b/biometrics/face/aidl/default/Face.h
index 809b856..786b4f8 100644
--- a/biometrics/face/aidl/default/Face.h
+++ b/biometrics/face/aidl/default/Face.h
@@ -27,8 +27,6 @@
ndk::ScopedAStatus createSession(int32_t sensorId, int32_t userId,
const std::shared_ptr<ISessionCallback>& cb,
std::shared_ptr<ISession>* _aidl_return) override;
-
- ndk::ScopedAStatus reset() override;
};
} // namespace aidl::android::hardware::biometrics::face
diff --git a/biometrics/fingerprint/aidl/aidl_api/android.hardware.biometrics.fingerprint/current/android/hardware/biometrics/fingerprint/IFingerprint.aidl b/biometrics/fingerprint/aidl/aidl_api/android.hardware.biometrics.fingerprint/current/android/hardware/biometrics/fingerprint/IFingerprint.aidl
index 07777c9..5d3df6f 100644
--- a/biometrics/fingerprint/aidl/aidl_api/android.hardware.biometrics.fingerprint/current/android/hardware/biometrics/fingerprint/IFingerprint.aidl
+++ b/biometrics/fingerprint/aidl/aidl_api/android.hardware.biometrics.fingerprint/current/android/hardware/biometrics/fingerprint/IFingerprint.aidl
@@ -36,5 +36,4 @@
interface IFingerprint {
android.hardware.biometrics.fingerprint.SensorProps[] getSensorProps();
android.hardware.biometrics.fingerprint.ISession createSession(in int sensorId, in int userId, in android.hardware.biometrics.fingerprint.ISessionCallback cb);
- void reset();
}
diff --git a/biometrics/fingerprint/aidl/android/hardware/biometrics/fingerprint/IFingerprint.aidl b/biometrics/fingerprint/aidl/android/hardware/biometrics/fingerprint/IFingerprint.aidl
index 37062ba..98a4530 100644
--- a/biometrics/fingerprint/aidl/android/hardware/biometrics/fingerprint/IFingerprint.aidl
+++ b/biometrics/fingerprint/aidl/android/hardware/biometrics/fingerprint/IFingerprint.aidl
@@ -65,14 +65,4 @@
* @return A new session
*/
ISession createSession(in int sensorId, in int userId, in ISessionCallback cb);
-
- /**
- * Resets the HAL into a clean state, forcing it to cancel all of the pending operations, close
- * its current session, and release all of the acquired resources.
- *
- * This should be used as a last resort to recover the HAL if the current session becomes
- * unresponsive. The implementation might choose to restart the HAL process to get back into a
- * good state.
- */
- void reset();
}
diff --git a/biometrics/fingerprint/aidl/default/Fingerprint.cpp b/biometrics/fingerprint/aidl/default/Fingerprint.cpp
index 206f518..79f48fe 100644
--- a/biometrics/fingerprint/aidl/default/Fingerprint.cpp
+++ b/biometrics/fingerprint/aidl/default/Fingerprint.cpp
@@ -63,10 +63,4 @@
return ndk::ScopedAStatus::ok();
}
-ndk::ScopedAStatus Fingerprint::reset() {
- // Crash. The system will start a fresh instance of the HAL.
- CHECK(false) << "Unable to reset. Crashing.";
- return ndk::ScopedAStatus::ok();
-}
-
} // namespace aidl::android::hardware::biometrics::fingerprint
diff --git a/biometrics/fingerprint/aidl/default/Session.cpp b/biometrics/fingerprint/aidl/default/Session.cpp
index 9e6ac77..c035407 100644
--- a/biometrics/fingerprint/aidl/default/Session.cpp
+++ b/biometrics/fingerprint/aidl/default/Session.cpp
@@ -46,7 +46,7 @@
void Session::enterStateOrCrash(int cookie, SessionState state) {
CHECK(mScheduledState == state);
- mCurrentState = mScheduledState;
+ mCurrentState = state;
mScheduledState = SessionState::IDLING;
mCb->onStateChanged(cookie, mCurrentState);
}
diff --git a/biometrics/fingerprint/aidl/default/include/Fingerprint.h b/biometrics/fingerprint/aidl/default/include/Fingerprint.h
index 9b43419..7bd3d6d 100644
--- a/biometrics/fingerprint/aidl/default/include/Fingerprint.h
+++ b/biometrics/fingerprint/aidl/default/include/Fingerprint.h
@@ -34,8 +34,6 @@
const std::shared_ptr<ISessionCallback>& cb,
std::shared_ptr<ISession>* out) override;
- ndk::ScopedAStatus reset() override;
-
private:
std::unique_ptr<FakeFingerprintEngine> mEngine;
WorkerThread mWorker;
diff --git a/biometrics/fingerprint/aidl/default/include/Session.h b/biometrics/fingerprint/aidl/default/include/Session.h
index d2f0c19..97d5645 100644
--- a/biometrics/fingerprint/aidl/default/include/Session.h
+++ b/biometrics/fingerprint/aidl/default/include/Session.h
@@ -82,13 +82,28 @@
// by calling ISessionCallback#onStateChanged.
void enterIdling(int cookie);
+ // The sensor and user IDs for which this session was created.
int32_t mSensorId;
int32_t mUserId;
+
+ // Callback for talking to the framework. This callback must only be called from non-binder
+ // threads to prevent nested binder calls and consequently a binder thread exhaustion.
+ // Practically, it means that this callback should always be called from the worker thread.
std::shared_ptr<ISessionCallback> mCb;
+
+ // Module that communicates to the actual fingerprint hardware, keystore, TEE, etc. In real
+ // life such modules typically consume a lot of memory and are slow to initialize. This is here
+ // to showcase how such a module can be used within a Session without incurring the high
+ // initialization costs every time a Session is constructed.
FakeFingerprintEngine* mEngine;
+
+ // Worker thread that allows to schedule tasks for asynchronous execution.
WorkerThread* mWorker;
- SessionState mScheduledState;
- SessionState mCurrentState;
+
+ // Simple representation of the session's state machine. These are atomic because they can be
+ // modified from both the main and the worker threads.
+ std::atomic<SessionState> mScheduledState;
+ std::atomic<SessionState> mCurrentState;
};
} // namespace aidl::android::hardware::biometrics::fingerprint
diff --git a/broadcastradio/2.0/vts/functional/VtsHalBroadcastradioV2_0TargetTest.cpp b/broadcastradio/2.0/vts/functional/VtsHalBroadcastradioV2_0TargetTest.cpp
index 5ba7a76..362ab41 100644
--- a/broadcastradio/2.0/vts/functional/VtsHalBroadcastradioV2_0TargetTest.cpp
+++ b/broadcastradio/2.0/vts/functional/VtsHalBroadcastradioV2_0TargetTest.cpp
@@ -495,7 +495,7 @@
* invoked carrying a proper selector;
* - program changes exactly to what was requested.
*/
-TEST_F(BroadcastRadioHalTest, DabTune) {
+TEST_P(BroadcastRadioHalTest, DabTune) {
ASSERT_TRUE(openSession());
ProgramSelector sel = {};
diff --git a/compatibility_matrices/compatibility_matrix.5.xml b/compatibility_matrices/compatibility_matrix.5.xml
index 96a3692..8e175f0 100644
--- a/compatibility_matrices/compatibility_matrix.5.xml
+++ b/compatibility_matrices/compatibility_matrix.5.xml
@@ -256,6 +256,13 @@
</hal>
<hal format="aidl" optional="true">
<name>android.hardware.identity</name>
+ <!--
+ b/178458001: identity V2 is introduced in R, but Android R VINTF does not support AIDL
+ versions. Hence, we only specify identity V2 in compatibility_matrix.5.xml in Android S+
+ branches. In Android R branches, the matrix implicitly specifies V1.
+ SingleManifestTest.ManifestAidlHalsServed has an exemption for this.
+ -->
+ <version>1-2</version>
<interface>
<name>IIdentityCredentialStore</name>
<instance>default</instance>
diff --git a/current.txt b/current.txt
index af50841..6c576ca 100644
--- a/current.txt
+++ b/current.txt
@@ -780,8 +780,8 @@
dabe23dde7c9e3ad65c61def7392f186d7efe7f4216f9b6f9cf0863745b1a9f4 android.hardware.keymaster@4.1::IKeymasterDevice
cd84ab19c590e0e73dd2307b591a3093ee18147ef95e6d5418644463a6620076 android.hardware.neuralnetworks@1.2::IDevice
f729ee6a5f136b25d79ea6895d24700fce413df555baaecf2c39e4440d15d043 android.hardware.neuralnetworks@1.0::types
-ba84f3a750b1cc43ac51074e8b8e22df924f3e6d9068fac50d95bcf57b2b1d61 android.hardware.neuralnetworks@1.2::types
-9fe5a4093043c2b5da4e9491aed1646c388a5d3059b8fd77d5b6a9807e6d3a3e android.hardware.neuralnetworks@1.3::types
+a84f8dac7a9b75de1cc2936a9b429b9b62b32a31ea88ca52c29f98f5ddc0fa95 android.hardware.neuralnetworks@1.2::types
+cd331b92312d16ab89f475c39296abbf539efc4114a8c5c2b136ad99b904ef33 android.hardware.neuralnetworks@1.3::types
e8c86c69c438da8d1549856c1bb3e2d1b8da52722f8235ff49a30f2cce91742c android.hardware.soundtrigger@2.1::ISoundTriggerHwCallback
b9fbb6e2e061ed0960939d48b785e9700210add1f13ed32ecd688d0f1ca20ef7 android.hardware.renderscript@1.0::types
0f53d70e1eadf8d987766db4bf6ae2048004682168f4cab118da576787def3fa android.hardware.radio@1.0::types
diff --git a/gnss/aidl/aidl_api/android.hardware.gnss/current/android/hardware/gnss/CorrelationVector.aidl b/gnss/aidl/aidl_api/android.hardware.gnss/current/android/hardware/gnss/CorrelationVector.aidl
index 2d21748..9c9a241 100644
--- a/gnss/aidl/aidl_api/android.hardware.gnss/current/android/hardware/gnss/CorrelationVector.aidl
+++ b/gnss/aidl/aidl_api/android.hardware.gnss/current/android/hardware/gnss/CorrelationVector.aidl
@@ -33,7 +33,7 @@
package android.hardware.gnss;
@VintfStability
parcelable CorrelationVector {
- int frequencyOffsetMps;
+ double frequencyOffsetMps;
double samplingWidthM;
double samplingStartM;
int[] magnitude;
diff --git a/gnss/aidl/android/hardware/gnss/CorrelationVector.aidl b/gnss/aidl/android/hardware/gnss/CorrelationVector.aidl
index 22a80ce..6fbabbc 100644
--- a/gnss/aidl/android/hardware/gnss/CorrelationVector.aidl
+++ b/gnss/aidl/android/hardware/gnss/CorrelationVector.aidl
@@ -22,11 +22,10 @@
*/
@VintfStability
parcelable CorrelationVector {
-
/**
* Frequency offset from reported pseudorange rate for this Correlation Vector.
*/
- int frequencyOffsetMps;
+ double frequencyOffsetMps;
/**
* Space between correlation samples in meters.
@@ -48,4 +47,4 @@
* The length of the array is defined by the GNSS chipset.
*/
int[] magnitude;
-}
\ No newline at end of file
+}
diff --git a/identity/aidl/default/common/IdentityCredential.cpp b/identity/aidl/default/common/IdentityCredential.cpp
index 9477997..c8ee0dd 100644
--- a/identity/aidl/default/common/IdentityCredential.cpp
+++ b/identity/aidl/default/common/IdentityCredential.cpp
@@ -253,14 +253,17 @@
}
}
- // Feed the auth token to secure hardware.
- if (!hwProxy_->setAuthToken(authToken.challenge, authToken.userId, authToken.authenticatorId,
- int(authToken.authenticatorType), authToken.timestamp.milliSeconds,
- authToken.mac, verificationToken_.challenge,
- verificationToken_.timestamp.milliSeconds,
- int(verificationToken_.securityLevel), verificationToken_.mac)) {
- return ndk::ScopedAStatus(AStatus_fromServiceSpecificErrorWithMessage(
- IIdentityCredentialStore::STATUS_INVALID_DATA, "Invalid Auth Token"));
+ // Feed the auth token to secure hardware only if they're valid.
+ if (authToken.timestamp.milliSeconds != 0) {
+ if (!hwProxy_->setAuthToken(
+ authToken.challenge, authToken.userId, authToken.authenticatorId,
+ int(authToken.authenticatorType), authToken.timestamp.milliSeconds,
+ authToken.mac, verificationToken_.challenge,
+ verificationToken_.timestamp.milliSeconds,
+ int(verificationToken_.securityLevel), verificationToken_.mac)) {
+ return ndk::ScopedAStatus(AStatus_fromServiceSpecificErrorWithMessage(
+ IIdentityCredentialStore::STATUS_INVALID_DATA, "Invalid Auth Token"));
+ }
}
// We'll be feeding ACPs interleaved with certificates from the reader
diff --git a/identity/aidl/default/libeic/EicPresentation.c b/identity/aidl/default/libeic/EicPresentation.c
index 5e9a280..9e033b3 100644
--- a/identity/aidl/default/libeic/EicPresentation.c
+++ b/identity/aidl/default/libeic/EicPresentation.c
@@ -336,6 +336,18 @@
int verificationTokenSecurityLevel,
const uint8_t* verificationTokenMac,
size_t verificationTokenMacSize) {
+ // It doesn't make sense to accept any tokens if eicPresentationCreateAuthChallenge()
+ // was never called.
+ if (ctx->authChallenge == 0) {
+ eicDebug("Trying validate tokens when no auth-challenge was previously generated");
+ return false;
+ }
+ // At least the verification-token must have the same challenge as what was generated.
+ if (verificationTokenChallenge != ctx->authChallenge) {
+ eicDebug("Challenge in verification token does not match the challenge "
+ "previously generated");
+ return false;
+ }
if (!eicOpsValidateAuthToken(
challenge, secureUserId, authenticatorId, hardwareAuthenticatorType, timeStamp, mac,
macSize, verificationTokenChallenge, verificationTokenTimestamp,
@@ -360,18 +372,9 @@
return false;
}
+ // Only ACP with auth-on-every-presentation - those with timeout == 0 - need the
+ // challenge to match...
if (timeoutMillis == 0) {
- if (ctx->authTokenChallenge == 0) {
- eicDebug("No challenge in authToken");
- return false;
- }
-
- // If we didn't create a challenge, too bad but user auth with
- // timeoutMillis set to 0 needs it.
- if (ctx->authChallenge == 0) {
- eicDebug("No challenge was created for this session");
- return false;
- }
if (ctx->authTokenChallenge != ctx->authChallenge) {
eicDebug("Challenge in authToken (%" PRIu64
") doesn't match the challenge "
diff --git a/keymaster/4.0/vts/functional/keymaster_hidl_hal_test.cpp b/keymaster/4.0/vts/functional/keymaster_hidl_hal_test.cpp
index 5f81394..e0d60fc 100644
--- a/keymaster/4.0/vts/functional/keymaster_hidl_hal_test.cpp
+++ b/keymaster/4.0/vts/functional/keymaster_hidl_hal_test.cpp
@@ -136,48 +136,53 @@
return retval;
}
-string rsa_2048_key =
- hex2str("308204a50201000282010100caa620db7bbadfd351153a804e05a3115a0"
- "eea067316c7d6ae010086cc4d636edcc50b725c495027e79d7c6d65ec50"
- "5ab84107b0ca9f8389d0d812d42df3af0c1c50f1083b1eedd18921283e3"
- "9ebe95bd56795c9ba129afc63d60fb020b300c44861a73845508a992c54"
- "7cf4ce7694955c684bc130fe9a0478285d686da954989a7be3cd970de7e"
- "5eca8574c0617fed74717f7035655f65af7b5f9b982feca8eed643b96d8"
- "f1c4e6dcd96a9ccfcca3366d8f1c95f83a83ab785f997b78918ceca567d"
- "91cf2ea85c340c0d4462f31f8a31e648cd26e1116a97d17dcfec51e4336"
- "fa0725ff49216005911966748f94789c055795da023362091c977bdc0bd"
- "8e31902030100010282010100ca562da0785e1275d013be21b5c5731834"
- "2f8803808e52624bc2bc5fdb45b9ee4b8882f160abe2d8b52e4dba7d760"
- "295523bbc0e0d824fb81f4a5f2273ef47ec73a96dc0a6272f9573b22398"
- "5e04eb2fc25876fac04b2b6cadd2623f9da69d315e84028ef0c6865c822"
- "2a9d15504993eb8d17a321f55573af72e76757a690408c36909eb44a555"
- "4b571007edde150b47952287d942559e7f8cbcb2c47086aa291515f55c4"
- "deba6d1ebde0cca5ee899b3b0c4c21123bbf92feac53db515fe02d03b83"
- "2154e31122abcbb6fc80b49e1c8fc5528605935f8f6ead1237b16e83d23"
- "ad73e82ee008c3ff7b4666f4c137c20f52ae6fea5b54ed104c1c1bf75fc"
- "3c020102818100efa6b29bb0f6b81c8fecf3e73c3e5a59b71ffd31075c4"
- "0282269ee245367c2e54f0244301dad0b90dcce73f25c1caca2f4ef1774"
- "42a5d9e98a354bcd5ddae129bea2c0771d1ad51341f44ddf0c5c0f22252"
- "414e2de7af6c67754dba610ee2743f21789a89829ad91efc02c7c5588fe"
- "84b64df12dc5cee90df2e7dd4a1ca2886902818100d87937f039df50054"
- "7c7d5435ec8e89789b36a0e5c4004d4612a6ef2dce39ee4f24fb5d2da38"
- "dbf5f3d639681a11fc416618554b1ff51a8215446b676363f6a5e91ea6c"
- "957483e0a47ae36582bde9fba45c00e6e3fadc651cc87c170171d7fef6d"
- "0dc1f0ddb6eca2674064925b78542b32f2821605c29b6d0b65485081f5a"
- "f3102818100ee21453ee153f6d422cb7ffc586758dde6d239835b5df63e"
- "2b1bf94f4d35407b1ccc12b780f56f15ade2d36192d7c74f5174b66886c"
- "5484800563f113cde7e783d7e7922a2e003b3d4088ecc40fac4ead7df07"
- "85fb2e524219574fbeaefa063844b9d0c69f1462ed2d3f56b4e145742aa"
- "8ffbfd40cc731daf37023fa3d83df6902818055dc2e8dbfc68d2caafddd"
- "deacd7af397bca87c44e5eae0bb6c667df3831a83252d1bee274df9c8ef"
- "f39f6e70d8018b7afd0f2f3ab27426e5a151b2c94c56f6cfafbc75790a0"
- "fcca8307dc5238844282556c09cd3cc0a62a879f48e036aae2b58a61ac8"
- "ce6c3c933d914374fbdac0a665ffcc4100c14d624f82221fe9cad5fe102"
- "818100964193ee55581c9a82fe03f8eb018cdce8965f30745cc6e68154c"
- "b6618ef3cc57ae4798ff2a509306a135f7cf705ceb215fda6939c7a6353"
- "0c86a5ba02f491a64f6079e62b1b00b86859899febf3ed300edcc0b8b35"
- "1855a90d9d39a279be963f0972a256084a3c46575f796ad27dc801f67a3"
- "7a59e62e076b996f025a9c9042");
+/*
+ * DER-encoded PKCS#8 format RSA key. Generated using:
+ *
+ * openssl genrsa 2048 | openssl pkcs8 -topk8 -nocrypt -outform der | hexdump -e '30/1 "%02X" "\n"'
+ */
+string rsa_2048_key = hex2str(
+ "308204BD020100300D06092A864886F70D0101010500048204A7308204A3"
+ "0201000282010100BEBC342B56D443B1299F9A6A7056E80A897E318476A5"
+ "A18029E63B2ED739A61791D339F58DC763D9D14911F2EDEC383DEE11F631"
+ "9B44510E7A3ECD9B79B97382E49500ACF8117DC89CAF0E621F77756554A2"
+ "FD4664BFE7AB8B59AB48340DBFA27B93B5A81F6ECDEB02D0759307128DF3"
+ "E3BAD4055C8B840216DFAA5700670E6C5126F0962FCB70FF308F25049164"
+ "CCF76CC2DA66A7DD9A81A714C2809D69186133D29D84568E892B6FFBF319"
+ "9BDB14383EE224407F190358F111A949552ABA6714227D1BD7F6B20DD0CB"
+ "88F9467B719339F33BFF35B3870B3F62204E4286B0948EA348B524544B5F"
+ "9838F29EE643B079EEF8A713B220D7806924CDF7295070C5020301000102"
+ "82010069F377F35F2F584EF075353CCD1CA99738DB3DBC7C7FF35F9366CE"
+ "176DFD1B135AB10030344ABF5FBECF1D4659FDEF1C0FC430834BE1BE3911"
+ "951377BB3D563A2EA9CA8F4AD9C48A8CE6FD516A735C662686C7B4B3C09A"
+ "7B8354133E6F93F790D59EAEB92E84C9A4339302CCE28FDF04CCCAFA7DE3"
+ "F3A827D4F6F7D38E68B0EC6AB706645BF074A4E4090D06FB163124365FD5"
+ "EE7A20D350E9958CC30D91326E1B292E9EF5DB408EC42DAF737D20149704"
+ "D0A678A0FB5B5446863B099228A352D604BA8091A164D01D5AB05397C71E"
+ "AD20BE2A08FC528FE442817809C787FEE4AB97F97B9130D022153EDC6EB6"
+ "CBE7B0F8E3473F2E901209B5DB10F93604DB0102818100E83C0998214941"
+ "EA4F9293F1B77E2E99E6CF305FAF358238E126124FEAF2EB9724B2EA7B78"
+ "E6032343821A80E55D1D88FB12D220C3F41A56142FEC85796D1917F1E8C7"
+ "74F142B67D3D6E7B7E6B4383E94DB5929089DBB346D5BDAB40CC2D96EE04"
+ "09475E175C63BF78CFD744136740838127EA723FF3FE7FA368C1311B4A4E"
+ "0502818100D240FCC0F5D7715CDE21CB2DC86EA146132EA3B06F61FF2AF5"
+ "4BF38473F59DADCCE32B5F4CC32DD0BA6F509347B4B5B1B58C39F95E4798"
+ "CCBB43E83D0119ACF532F359CA743C85199F0286610E200997D731291717"
+ "9AC9B67558773212EC961E8BCE7A3CC809BC5486A96E4B0E6AF394D94E06"
+ "6A0900B7B70E82A44FB30053C102818100AD15DA1CBD6A492B66851BA8C3"
+ "16D38AB700E2CFDDD926A658003513C54BAA152B30021D667D20078F500F"
+ "8AD3E7F3945D74A891ED1A28EAD0FEEAEC8C14A8E834CF46A13D1378C99D"
+ "18940823CFDD27EC5810D59339E0C34198AC638E09C87CBB1B634A9864AE"
+ "9F4D5EB2D53514F67B4CAEC048C8AB849A02E397618F3271350281801FA2"
+ "C1A5331880A92D8F3E281C617108BF38244F16E352E69ED417C7153F9EC3"
+ "18F211839C643DCF8B4DD67CE2AC312E95178D5D952F06B1BF779F491692"
+ "4B70F582A23F11304E02A5E7565AE22A35E74FECC8B6FDC93F92A1A37703"
+ "E4CF0E63783BD02EB716A7ECBBFA606B10B74D01579522E7EF84D91FC522"
+ "292108D902C1028180796FE3825F9DCC85DF22D58690065D93898ACD65C0"
+ "87BEA8DA3A63BF4549B795E2CD0E3BE08CDEBD9FCF1720D9CDC5070D74F4"
+ "0DED8E1102C52152A31B6165F83A6722AECFCC35A493D7634664B888A08D"
+ "3EB034F12EA28BFEE346E205D334827F778B16ED40872BD29FCB36536B6E"
+ "93FFB06778696B4A9D81BB0A9423E63DE5");
string rsa_key = hex2str(
"30820275020100300d06092a864886f70d01010105000482025f3082025b"
diff --git a/neuralnetworks/1.2/types.hal b/neuralnetworks/1.2/types.hal
index e3cee93..03aed86 100644
--- a/neuralnetworks/1.2/types.hal
+++ b/neuralnetworks/1.2/types.hal
@@ -4895,25 +4895,25 @@
* Additional parameters specific to a particular operand type.
*/
safe_union ExtraParams {
- /**
- * No additional parameters.
- */
- Monostate none;
+ /**
+ * No additional parameters.
+ */
+ Monostate none;
- /**
- * Symmetric per-channel quantization parameters.
- *
- * Only applicable to operands of type TENSOR_QUANT8_SYMM_PER_CHANNEL.
- */
- SymmPerChannelQuantParams channelQuant;
+ /**
+ * Symmetric per-channel quantization parameters.
+ *
+ * Only applicable to operands of type TENSOR_QUANT8_SYMM_PER_CHANNEL.
+ */
+ SymmPerChannelQuantParams channelQuant;
- /**
- * Extension operand parameters.
- *
- * The framework treats this as an opaque data blob.
- * The format is up to individual extensions.
- */
- vec<uint8_t> extension;
+ /**
+ * Extension operand parameters.
+ *
+ * The framework treats this as an opaque data blob.
+ * The format is up to individual extensions.
+ */
+ vec<uint8_t> extension;
} extraParams;
};
diff --git a/neuralnetworks/1.2/types.t b/neuralnetworks/1.2/types.t
index 054d516..4c9fd02 100644
--- a/neuralnetworks/1.2/types.t
+++ b/neuralnetworks/1.2/types.t
@@ -291,25 +291,25 @@
* Additional parameters specific to a particular operand type.
*/
safe_union ExtraParams {
- /**
- * No additional parameters.
- */
- Monostate none;
+ /**
+ * No additional parameters.
+ */
+ Monostate none;
- /**
- * Symmetric per-channel quantization parameters.
- *
- * Only applicable to operands of type TENSOR_QUANT8_SYMM_PER_CHANNEL.
- */
- SymmPerChannelQuantParams channelQuant;
+ /**
+ * Symmetric per-channel quantization parameters.
+ *
+ * Only applicable to operands of type TENSOR_QUANT8_SYMM_PER_CHANNEL.
+ */
+ SymmPerChannelQuantParams channelQuant;
- /**
- * Extension operand parameters.
- *
- * The framework treats this as an opaque data blob.
- * The format is up to individual extensions.
- */
- vec<uint8_t> extension;
+ /**
+ * Extension operand parameters.
+ *
+ * The framework treats this as an opaque data blob.
+ * The format is up to individual extensions.
+ */
+ vec<uint8_t> extension;
} extraParams;
};
diff --git a/neuralnetworks/1.2/utils/test/DeviceTest.cpp b/neuralnetworks/1.2/utils/test/DeviceTest.cpp
index 9c8adde..215d44c 100644
--- a/neuralnetworks/1.2/utils/test/DeviceTest.cpp
+++ b/neuralnetworks/1.2/utils/test/DeviceTest.cpp
@@ -772,7 +772,7 @@
EXPECT_NE(result.value(), nullptr);
}
-TEST(DeviceTest, prepareModelFromCacheError) {
+TEST(DeviceTest, prepareModelFromCacheLaunchError) {
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice).value();
@@ -790,6 +790,23 @@
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
+// prepareModelFromCache is mocked to report ErrorStatus::NONE first and GENERAL_FAILURE second
+// (presumably launch status, then the status delivered through the callback); the canonical
+// Device must surface that failure instead of a prepared model.
+TEST(DeviceTest, prepareModelFromCacheReturnError) {
+    // Arrange
+    const auto mock = createMockDevice();
+    const auto nnDevice = Device::create(kName, mock).value();
+    const auto launchStatus = V1_0::ErrorStatus::NONE;
+    const auto returnStatus = V1_0::ErrorStatus::GENERAL_FAILURE;
+    EXPECT_CALL(*mock, prepareModelFromCache(_, _, _, _))
+            .Times(1)
+            .WillOnce(Invoke(
+                    makePreparedModelFromCacheReturn(launchStatus, returnStatus, nullptr)));
+
+    // Act
+    const auto prepared = nnDevice->prepareModelFromCache({}, {}, {}, {});
+
+    // Assert: no model was produced and the reported error code is propagated.
+    ASSERT_FALSE(prepared.has_value());
+    EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
TEST(DeviceTest, prepareModelFromCacheNullptrError) {
// setup call
const auto mockDevice = createMockDevice();
diff --git a/neuralnetworks/1.3/types.hal b/neuralnetworks/1.3/types.hal
index 51837fe..a5dbd5e 100644
--- a/neuralnetworks/1.3/types.hal
+++ b/neuralnetworks/1.3/types.hal
@@ -5340,7 +5340,6 @@
HIGH,
};
-
/**
* The capabilities of a driver.
*
diff --git a/neuralnetworks/1.3/types.t b/neuralnetworks/1.3/types.t
index 2901d18..9f69c9e 100644
--- a/neuralnetworks/1.3/types.t
+++ b/neuralnetworks/1.3/types.t
@@ -99,7 +99,6 @@
HIGH,
};
-
/**
* The capabilities of a driver.
*
diff --git a/neuralnetworks/1.3/utils/test/DeviceTest.cpp b/neuralnetworks/1.3/utils/test/DeviceTest.cpp
index f260990..2d1b2f2 100644
--- a/neuralnetworks/1.3/utils/test/DeviceTest.cpp
+++ b/neuralnetworks/1.3/utils/test/DeviceTest.cpp
@@ -794,7 +794,7 @@
EXPECT_NE(result.value(), nullptr);
}
-TEST(DeviceTest, prepareModelFromCacheError) {
+TEST(DeviceTest, prepareModelFromCacheLaunchError) {
// setup call
const auto mockDevice = createMockDevice();
const auto device = Device::create(kName, mockDevice).value();
@@ -812,6 +812,23 @@
EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
}
+// prepareModelFromCache_1_3 is mocked to report ErrorStatus::NONE first and GENERAL_FAILURE
+// second (presumably launch status, then the status delivered through the callback); the
+// canonical Device must surface that failure instead of a prepared model.
+TEST(DeviceTest, prepareModelFromCacheReturnError) {
+    // Arrange
+    const auto mock = createMockDevice();
+    const auto nnDevice = Device::create(kName, mock).value();
+    const auto launchStatus = V1_3::ErrorStatus::NONE;
+    const auto returnStatus = V1_3::ErrorStatus::GENERAL_FAILURE;
+    EXPECT_CALL(*mock, prepareModelFromCache_1_3(_, _, _, _, _))
+            .Times(1)
+            .WillOnce(Invoke(
+                    makePreparedModelFromCacheReturn(launchStatus, returnStatus, nullptr)));
+
+    // Act
+    const auto prepared = nnDevice->prepareModelFromCache({}, {}, {}, {});
+
+    // Assert: no model was produced and the reported error code is propagated.
+    ASSERT_FALSE(prepared.has_value());
+    EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
TEST(DeviceTest, prepareModelFromCacheNullptrError) {
// setup call
const auto mockDevice = createMockDevice();
diff --git a/neuralnetworks/TEST_MAPPING b/neuralnetworks/TEST_MAPPING
index 5d168d2..d296828 100644
--- a/neuralnetworks/TEST_MAPPING
+++ b/neuralnetworks/TEST_MAPPING
@@ -16,6 +16,9 @@
"name": "neuralnetworks_utils_hal_1_3_test"
},
{
+ "name": "neuralnetworks_utils_hal_aidl_test"
+ },
+ {
"name": "VtsHalNeuralnetworksV1_0TargetTest",
"options": [
{
diff --git a/neuralnetworks/aidl/utils/Android.bp b/neuralnetworks/aidl/utils/Android.bp
index 2673cae..476dac9 100644
--- a/neuralnetworks/aidl/utils/Android.bp
+++ b/neuralnetworks/aidl/utils/Android.bp
@@ -29,10 +29,12 @@
srcs: ["src/*"],
local_include_dirs: ["include/nnapi/hal/aidl/"],
export_include_dirs: ["include"],
+ cflags: ["-Wthread-safety"],
static_libs: [
"libarect",
"neuralnetworks_types",
"neuralnetworks_utils_hal_common",
+ "neuralnetworks_utils_hal_1_0",
],
shared_libs: [
"android.hardware.neuralnetworks-V1-ndk_platform",
@@ -41,3 +43,38 @@
"libnativewindow",
],
}
+
+cc_test {
+ name: "neuralnetworks_utils_hal_aidl_test",
+ defaults: ["neuralnetworks_utils_defaults"],
+ srcs: [
+ "test/*.cpp",
+ ],
+ static_libs: [
+ "android.hardware.common-V2-ndk_platform",
+ "android.hardware.neuralnetworks-V1-ndk_platform",
+ "libgmock",
+ "libneuralnetworks_common",
+ "neuralnetworks_types",
+ "neuralnetworks_utils_hal_aidl",
+ "neuralnetworks_utils_hal_common",
+ ],
+ shared_libs: [
+ "android.hidl.allocator@1.0",
+ "libbase",
+ "libbinder_ndk",
+ "libcutils",
+ "libhidlbase",
+ "libhidlmemory",
+ "liblog",
+ "libnativewindow",
+ "libutils",
+ ],
+ cflags: [
+ /* GMOCK defines functions for printing all MOCK_DEVICE arguments and
+ * MockDevice contains a string pointer which triggers a warning in the
+ * base logging library. */
+ "-Wno-user-defined-warnings",
+ ],
+ test_suites: ["general-tests"],
+}
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Buffer.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Buffer.h
new file mode 100644
index 0000000..46190c4
--- /dev/null
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Buffer.h
@@ -0,0 +1,56 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_BUFFER_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_BUFFER_H
+
+#include <aidl/android/hardware/neuralnetworks/IBuffer.h>
+#include <nnapi/IBuffer.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <memory>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// Class that adapts aidl_hal::IBuffer to nn::IBuffer.
+//
+// Instances are obtained through create(), which returns an error for a null buffer or a zero
+// token; the constructor itself CHECK-fails on those same conditions.
+class Buffer final : public nn::IBuffer {
+ // Private tag type: only Buffer can name it, so outside code must go through create().
+ struct PrivateConstructorTag {};
+
+ public:
+ // Returns an error if `buffer` is null or `token` is zero; otherwise a Buffer wrapping both.
+ static nn::GeneralResult<std::shared_ptr<const Buffer>> create(
+ std::shared_ptr<aidl_hal::IBuffer> buffer, nn::Request::MemoryDomainToken token);
+
+ // Public so std::make_shared can call it from create(); callers cannot produce the tag.
+ Buffer(PrivateConstructorTag tag, std::shared_ptr<aidl_hal::IBuffer> buffer,
+ nn::Request::MemoryDomainToken token);
+
+ // Returns the memory domain token supplied at construction.
+ nn::Request::MemoryDomainToken getToken() const override;
+
+ // Converts `dst` to its AIDL form and forwards to aidl_hal::IBuffer::copyTo.
+ nn::GeneralResult<void> copyTo(const nn::SharedMemory& dst) const override;
+ // Converts `src` and `dimensions` (to signed) and forwards to aidl_hal::IBuffer::copyFrom.
+ nn::GeneralResult<void> copyFrom(const nn::SharedMemory& src,
+ const nn::Dimensions& dimensions) const override;
+
+ private:
+ const std::shared_ptr<aidl_hal::IBuffer> kBuffer;
+ const nn::Request::MemoryDomainToken kToken;
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_BUFFER_H
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Callbacks.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Callbacks.h
new file mode 100644
index 0000000..8651912
--- /dev/null
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Callbacks.h
@@ -0,0 +1,53 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_CALLBACKS_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_CALLBACKS_H
+
+#include <aidl/android/hardware/neuralnetworks/BnPreparedModelCallback.h>
+#include <aidl/android/hardware/neuralnetworks/IDevice.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <nnapi/hal/TransferValue.h>
+#include <nnapi/hal/aidl/ProtectCallback.h>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// An AIDL callback class to receive the results of IDevice::prepareModel* asynchronously.
+//
+// notify() is called by the remote side; the converted result is stored in mData and presumably
+// retrieved by the waiting caller through get(). notifyAsDeadObject() comes from
+// hal::utils::IProtectedCallback and is presumably invoked when the remote service dies, so a
+// pending get() is not left waiting forever (confirm in Callbacks.cpp).
+class PreparedModelCallback final : public BnPreparedModelCallback,
+ public hal::utils::IProtectedCallback {
+ public:
+ // Either the wrapped prepared model or the error that prevented preparation.
+ using Data = nn::GeneralResult<nn::SharedPreparedModel>;
+
+ // Converts (`status`, `preparedModel`) to Data, stores it, and always returns an OK status.
+ ndk::ScopedAStatus notify(ErrorStatus status,
+ const std::shared_ptr<IPreparedModel>& preparedModel) override;
+
+ void notifyAsDeadObject() override;
+
+ // NOTE(review): presumably blocks until a result has been stored by notify() or
+ // notifyAsDeadObject(); TransferValue's semantics are not visible here.
+ Data get();
+
+ private:
+ hal::utils::TransferValue<Data> mData;
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_CALLBACKS_H
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Conversions.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Conversions.h
index 1b2f69c..4922a6e 100644
--- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Conversions.h
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Conversions.h
@@ -46,6 +46,7 @@
#include <aidl/android/hardware/neuralnetworks/SymmPerChannelQuantParams.h>
#include <aidl/android/hardware/neuralnetworks/Timing.h>
+#include <android/binder_auto_utils.h>
#include <nnapi/Result.h>
#include <nnapi/Types.h>
#include <nnapi/hal/CommonUtils.h>
@@ -96,7 +97,11 @@
const aidl_hal::ExtensionOperandTypeInformation& operandTypeInformation);
GeneralResult<SharedHandle> unvalidatedConvert(
const ::aidl::android::hardware::common::NativeHandle& handle);
+GeneralResult<SyncFence> unvalidatedConvert(const ndk::ScopedFileDescriptor& syncFence);
+GeneralResult<Capabilities> convert(const aidl_hal::Capabilities& capabilities);
+GeneralResult<DeviceType> convert(const aidl_hal::DeviceType& deviceType);
+GeneralResult<ErrorStatus> convert(const aidl_hal::ErrorStatus& errorStatus);
GeneralResult<ExecutionPreference> convert(
const aidl_hal::ExecutionPreference& executionPreference);
GeneralResult<SharedMemory> convert(const aidl_hal::Memory& memory);
@@ -106,9 +111,14 @@
GeneralResult<Priority> convert(const aidl_hal::Priority& priority);
GeneralResult<Request::MemoryPool> convert(const aidl_hal::RequestMemoryPool& memoryPool);
GeneralResult<Request> convert(const aidl_hal::Request& request);
+GeneralResult<Timing> convert(const aidl_hal::Timing& timing);
+GeneralResult<SyncFence> convert(const ndk::ScopedFileDescriptor& syncFence);
+GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extension);
GeneralResult<std::vector<Operation>> convert(const std::vector<aidl_hal::Operation>& outputShapes);
GeneralResult<std::vector<SharedMemory>> convert(const std::vector<aidl_hal::Memory>& memories);
+GeneralResult<std::vector<OutputShape>> convert(
+ const std::vector<aidl_hal::OutputShape>& outputShapes);
GeneralResult<std::vector<uint32_t>> toUnsigned(const std::vector<int32_t>& vec);
@@ -118,14 +128,62 @@
namespace nn = ::android::nn;
+nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(const nn::CacheToken& cacheToken);
+nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc);
+nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole);
+nn::GeneralResult<bool> unvalidatedConvert(const nn::MeasureTiming& measureTiming);
nn::GeneralResult<Memory> unvalidatedConvert(const nn::SharedMemory& memory);
nn::GeneralResult<OutputShape> unvalidatedConvert(const nn::OutputShape& outputShape);
nn::GeneralResult<ErrorStatus> unvalidatedConvert(const nn::ErrorStatus& errorStatus);
+nn::GeneralResult<ExecutionPreference> unvalidatedConvert(
+ const nn::ExecutionPreference& executionPreference);
+nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType);
+nn::GeneralResult<OperandLifeTime> unvalidatedConvert(const nn::Operand::LifeTime& operandLifeTime);
+nn::GeneralResult<DataLocation> unvalidatedConvert(const nn::DataLocation& location);
+nn::GeneralResult<std::optional<OperandExtraParams>> unvalidatedConvert(
+ const nn::Operand::ExtraParams& extraParams);
+nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand);
+nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType);
+nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation);
+nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph);
+nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
+ const nn::Model::OperandValues& operandValues);
+nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
+ const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix);
+nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model);
+nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority);
+nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request);
+nn::GeneralResult<RequestArgument> unvalidatedConvert(const nn::Request::Argument& requestArgument);
+nn::GeneralResult<RequestMemoryPool> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool);
+nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing);
+nn::GeneralResult<int64_t> unvalidatedConvert(const nn::Duration& duration);
+nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalDuration& optionalDuration);
+nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalTimePoint& optionalTimePoint);
+nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvert(const nn::SyncFence& syncFence);
+nn::GeneralResult<common::NativeHandle> unvalidatedConvert(const nn::SharedHandle& sharedHandle);
+nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvertCache(
+ const nn::SharedHandle& handle);
+nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& cacheToken);
+nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& bufferDesc);
+nn::GeneralResult<bool> convert(const nn::MeasureTiming& measureTiming);
nn::GeneralResult<Memory> convert(const nn::SharedMemory& memory);
nn::GeneralResult<ErrorStatus> convert(const nn::ErrorStatus& errorStatus);
+nn::GeneralResult<ExecutionPreference> convert(const nn::ExecutionPreference& executionPreference);
+nn::GeneralResult<Model> convert(const nn::Model& model);
+nn::GeneralResult<Priority> convert(const nn::Priority& priority);
+nn::GeneralResult<Request> convert(const nn::Request& request);
+nn::GeneralResult<Timing> convert(const nn::Timing& timing);
+nn::GeneralResult<int64_t> convert(const nn::OptionalDuration& optionalDuration);
+nn::GeneralResult<int64_t> convert(const nn::OptionalTimePoint& optionalTimePoint);
+
+nn::GeneralResult<std::vector<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles);
nn::GeneralResult<std::vector<OutputShape>> convert(
const std::vector<nn::OutputShape>& outputShapes);
+nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
+ const std::vector<nn::SharedHandle>& handles);
+nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
+ const std::vector<nn::SyncFence>& syncFences);
nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec);
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Device.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Device.h
new file mode 100644
index 0000000..eb194e3
--- /dev/null
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Device.h
@@ -0,0 +1,98 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_DEVICE_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_DEVICE_H
+
+#include <aidl/android/hardware/neuralnetworks/IDevice.h>
+#include <nnapi/IBuffer.h>
+#include <nnapi/IDevice.h>
+#include <nnapi/OperandTypes.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <nnapi/hal/aidl/ProtectCallback.h>
+
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <vector>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// Class that adapts aidl_hal::IDevice to nn::IDevice.
+//
+// Static device properties (name, version string, type, extensions, capabilities, cache-file
+// counts) are passed to the constructor and kept as const members, together with the AIDL device
+// and its DeathHandler. NOTE(review): the accessors presumably return these cached values rather
+// than querying the driver again — confirm in Device.cpp.
+class Device final : public nn::IDevice {
+ // Private tag type: only Device can name it, so outside code must go through create().
+ struct PrivateConstructorTag {};
+
+ public:
+ // NOTE(review): implementation not visible here; presumably queries the remaining properties
+ // from `device` and builds its DeathHandler before constructing the Device.
+ static nn::GeneralResult<std::shared_ptr<const Device>> create(
+ std::string name, std::shared_ptr<aidl_hal::IDevice> device);
+
+ Device(PrivateConstructorTag tag, std::string name, std::string versionString,
+ nn::DeviceType deviceType, std::vector<nn::Extension> extensions,
+ nn::Capabilities capabilities, std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded,
+ std::shared_ptr<aidl_hal::IDevice> device, DeathHandler deathHandler);
+
+ // Accessors for the properties captured at construction.
+ const std::string& getName() const override;
+ const std::string& getVersionString() const override;
+ nn::Version getFeatureLevel() const override;
+ nn::DeviceType getType() const override;
+ bool isUpdatable() const override;
+ const std::vector<nn::Extension>& getSupportedExtensions() const override;
+ const nn::Capabilities& getCapabilities() const override;
+ std::pair<uint32_t, uint32_t> getNumberOfCacheFilesNeeded() const override;
+
+ nn::GeneralResult<void> wait() const override;
+
+ nn::GeneralResult<std::vector<bool>> getSupportedOperations(
+ const nn::Model& model) const override;
+
+ nn::GeneralResult<nn::SharedPreparedModel> prepareModel(
+ const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
+ nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
+ const std::vector<nn::SharedHandle>& dataCache,
+ const nn::CacheToken& token) const override;
+
+ nn::GeneralResult<nn::SharedPreparedModel> prepareModelFromCache(
+ nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
+ const std::vector<nn::SharedHandle>& dataCache,
+ const nn::CacheToken& token) const override;
+
+ nn::GeneralResult<nn::SharedBuffer> allocate(
+ const nn::BufferDesc& desc, const std::vector<nn::SharedPreparedModel>& preparedModels,
+ const std::vector<nn::BufferRole>& inputRoles,
+ const std::vector<nn::BufferRole>& outputRoles) const override;
+
+ // Returns a non-owning pointer. NOTE(review): presumably the monitor held by kDeathHandler,
+ // valid only while this Device is alive — confirm in Device.cpp.
+ DeathMonitor* getDeathMonitor() const;
+
+ private:
+ const std::string kName;
+ const std::string kVersionString;
+ const nn::DeviceType kDeviceType;
+ const std::vector<nn::Extension> kExtensions;
+ const nn::Capabilities kCapabilities;
+ const std::pair<uint32_t, uint32_t> kNumberOfCacheFilesNeeded;
+ const std::shared_ptr<aidl_hal::IDevice> kDevice;
+ // Death handler supplied at construction for kDevice.
+ const DeathHandler kDeathHandler;
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_DEVICE_H
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/PreparedModel.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/PreparedModel.h
new file mode 100644
index 0000000..9b28588
--- /dev/null
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/PreparedModel.h
@@ -0,0 +1,70 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_PREPARED_MODEL_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_PREPARED_MODEL_H
+
+#include <aidl/android/hardware/neuralnetworks/IPreparedModel.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <nnapi/hal/aidl/ProtectCallback.h>
+
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// Class that adapts aidl_hal::IPreparedModel to nn::IPreparedModel.
+//
+// NOTE(review): derives from std::enable_shared_from_this, presumably so objects created from it
+// (e.g. execution bursts or fenced-execution callbacks) can keep the model alive — confirm in
+// PreparedModel.cpp.
+class PreparedModel final : public nn::IPreparedModel,
+ public std::enable_shared_from_this<PreparedModel> {
+ // Private tag type: only PreparedModel can name it, so outside code must go through create().
+ struct PrivateConstructorTag {};
+
+ public:
+ // NOTE(review): presumably rejects a null `preparedModel` with an error, mirroring
+ // Buffer::create — implementation not visible here.
+ static nn::GeneralResult<std::shared_ptr<const PreparedModel>> create(
+ std::shared_ptr<aidl_hal::IPreparedModel> preparedModel);
+
+ PreparedModel(PrivateConstructorTag tag,
+ std::shared_ptr<aidl_hal::IPreparedModel> preparedModel);
+
+ nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> execute(
+ const nn::Request& request, nn::MeasureTiming measure,
+ const nn::OptionalTimePoint& deadline,
+ const nn::OptionalDuration& loopTimeoutDuration) const override;
+
+ nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>> executeFenced(
+ const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
+ nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
+ const nn::OptionalDuration& loopTimeoutDuration,
+ const nn::OptionalDuration& timeoutDurationAfterFence) const override;
+
+ nn::GeneralResult<nn::SharedBurst> configureExecutionBurst() const override;
+
+ // NOTE(review): std::any is used without a direct <any> include in this header; it is
+ // presumably provided transitively by nnapi/IPreparedModel.h.
+ std::any getUnderlyingResource() const override;
+
+ private:
+ const std::shared_ptr<aidl_hal::IPreparedModel> kPreparedModel;
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_PREPARED_MODEL_H
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h
new file mode 100644
index 0000000..ab1108c
--- /dev/null
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/ProtectCallback.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_PROTECT_CALLBACK_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_PROTECT_CALLBACK_H
+
+#include <android-base/scopeguard.h>
+#include <android-base/thread_annotations.h>
+#include <android/binder_interface_utils.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <nnapi/hal/ProtectCallback.h>
+
+#include <functional>
+#include <mutex>
+#include <vector>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// Thread-safe: the list of registered callbacks (mObjects) is guarded by mMutex.
+//
+// Stores non-owning raw pointers to hal::utils::IProtectedCallback objects; add() and remove()
+// presumably register and unregister them. NOTE(review): serviceDied() presumably notifies every
+// registered callback when the remote binder dies, and the static overload receives the
+// death-recipient cookie and forwards to it — confirm in ProtectCallback.cpp.
+class DeathMonitor final {
+ public:
+ static void serviceDied(void* cookie);
+ void serviceDied();
+ // Precondition: `killable` must be non-null.
+ void add(hal::utils::IProtectedCallback* killable) const;
+ // Precondition: `killable` must be non-null.
+ void remove(hal::utils::IProtectedCallback* killable) const;
+
+ private:
+ // add()/remove() are const member functions; the mutex and list are mutable so they can
+ // still lock and modify them.
+ mutable std::mutex mMutex;
+ mutable std::vector<hal::utils::IProtectedCallback*> mObjects GUARDED_BY(mMutex);
+};
+
+// Move-only helper that ties a remote AIDL object (kObject) to a binder death recipient and a
+// shared DeathMonitor. NOTE(review): create() and the destructor are defined elsewhere;
+// presumably create() links the death recipient to `object` and the destructor unlinks it —
+// confirm in ProtectCallback.cpp.
+class DeathHandler final {
+ public:
+ static nn::GeneralResult<DeathHandler> create(std::shared_ptr<ndk::ICInterface> object);
+
+ // Copying is disabled and move-assignment is deleted; only move-construction is allowed
+ // (needed, for example, to return a DeathHandler inside nn::GeneralResult from create()).
+ DeathHandler(const DeathHandler&) = delete;
+ DeathHandler(DeathHandler&&) noexcept = default;
+ DeathHandler& operator=(const DeathHandler&) = delete;
+ DeathHandler& operator=(DeathHandler&&) noexcept = delete;
+ ~DeathHandler();
+
+ using Cleanup = std::function<void()>;
+ // Precondition: `killable` must be non-null.
+ // The returned ScopeGuard runs its Cleanup when destroyed, which is why the result is
+ // [[nodiscard]]: discarding it would run the cleanup immediately. NOTE(review): the cleanup
+ // presumably removes `killable` from the DeathMonitor — confirm in ProtectCallback.cpp.
+ [[nodiscard]] ::android::base::ScopeGuard<Cleanup> protectCallback(
+ hal::utils::IProtectedCallback* killable) const;
+
+ // Returns shared ownership of the monitor, so it may outlive this handler.
+ std::shared_ptr<DeathMonitor> getDeathMonitor() const { return kDeathMonitor; }
+
+ private:
+ DeathHandler(std::shared_ptr<ndk::ICInterface> object,
+ ndk::ScopedAIBinder_DeathRecipient deathRecipient,
+ std::shared_ptr<DeathMonitor> deathMonitor);
+
+ // These carry the k prefix used for constant members but are not const. The defaulted move
+ // constructor must move them, and ScopedAIBinder_DeathRecipient is presumably move-only, so a
+ // const member would make that constructor deleted.
+ std::shared_ptr<ndk::ICInterface> kObject;
+ ndk::ScopedAIBinder_DeathRecipient kDeathRecipient;
+ std::shared_ptr<DeathMonitor> kDeathMonitor;
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_PROTECT_CALLBACK_H
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Service.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Service.h
new file mode 100644
index 0000000..b4587ac
--- /dev/null
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Service.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_SERVICE_H
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_SERVICE_H
+
+#include <nnapi/IDevice.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+
+#include <string>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// NOTE(review): implementation not visible here; presumably obtains the AIDL NN service instance
+// identified by `name` and wraps it in an nn::IDevice adapter, returning an error on failure.
+nn::GeneralResult<nn::SharedDevice> getDevice(const std::string& name);
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_SERVICE_H
diff --git a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Utils.h b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Utils.h
index 79b511d..58dcfe3 100644
--- a/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Utils.h
+++ b/neuralnetworks/aidl/utils/include/nnapi/hal/aidl/Utils.h
@@ -23,6 +23,7 @@
#include <nnapi/Result.h>
#include <nnapi/Types.h>
#include <nnapi/Validation.h>
+#include <nnapi/hal/HandleError.h>
namespace aidl::android::hardware::neuralnetworks::utils {
@@ -52,6 +53,12 @@
nn::GeneralResult<RequestMemoryPool> clone(const RequestMemoryPool& requestPool);
nn::GeneralResult<Model> clone(const Model& model);
+// Used by HANDLE_ASTATUS below. NOTE(review): the definition is not visible in this header;
+// presumably it returns an nn::GeneralError when `ret` reports a failure, and success otherwise.
+nn::GeneralResult<void> handleTransportError(const ndk::ScopedAStatus& ret);
+
+// Returns early from the enclosing function with an NN_ERROR when `ret` (an ndk::ScopedAStatus)
+// does not report success. Callers may stream additional context, e.g.
+//     HANDLE_ASTATUS(ret) << "IBuffer::copyTo failed";
+//
+// The helper is fully qualified so the macro also works outside this namespace: ADL on
+// ndk::ScopedAStatus only searches `ndk`. The loop variable has a deliberately unusual name
+// because a common name such as `status` would break HANDLE_ASTATUS(status): the macro's own
+// `const auto` variable would appear in its own initializer, which does not compile.
+#define HANDLE_ASTATUS(ret)                                                                   \
+    for (const auto nnHandleAStatusResult =                                                  \
+                 ::aidl::android::hardware::neuralnetworks::utils::handleTransportError(ret);   \
+         !nnHandleAStatusResult.ok();)                                                        \
+    return NN_ERROR(nnHandleAStatusResult.error().code)                                      \
+           << nnHandleAStatusResult.error().message << ": "
+
} // namespace aidl::android::hardware::neuralnetworks::utils
#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_H
diff --git a/neuralnetworks/aidl/utils/src/Buffer.cpp b/neuralnetworks/aidl/utils/src/Buffer.cpp
new file mode 100644
index 0000000..c729a68
--- /dev/null
+++ b/neuralnetworks/aidl/utils/src/Buffer.cpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Buffer.h"
+
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+
+#include "Conversions.h"
+#include "Utils.h"
+#include "nnapi/hal/aidl/Conversions.h"
+
+#include <memory>
+#include <utility>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+nn::GeneralResult<std::shared_ptr<const Buffer>> Buffer::create(
+        std::shared_ptr<aidl_hal::IBuffer> buffer, nn::Request::MemoryDomainToken token) {
+    // Validate here so the constructor's CHECKs can never fire for input that came through create().
+    const bool hasBuffer = (buffer != nullptr);
+    if (!hasBuffer) {
+        return NN_ERROR() << "aidl_hal::utils::Buffer::create must have non-null buffer";
+    }
+
+    constexpr auto kInvalidToken = static_cast<nn::Request::MemoryDomainToken>(0);
+    const bool hasToken = (token != kInvalidToken);
+    if (!hasToken) {
+        return NN_ERROR() << "aidl_hal::utils::Buffer::create must have non-zero token";
+    }
+
+    // make_shared can reach the public constructor; only Buffer can supply the private tag.
+    return std::make_shared<const Buffer>(PrivateConstructorTag{}, std::move(buffer), token);
+}
+
+// Takes ownership of the AIDL buffer handle. create() has already validated both arguments; the
+// CHECKs catch any other path that manages to construct a Buffer.
+Buffer::Buffer(PrivateConstructorTag /*tag*/, std::shared_ptr<aidl_hal::IBuffer> aidlBuffer,
+               nn::Request::MemoryDomainToken domainToken)
+    : kBuffer(std::move(aidlBuffer)), kToken(domainToken) {
+    CHECK(kBuffer != nullptr);
+    CHECK(kToken != static_cast<nn::Request::MemoryDomainToken>(0));
+}
+
+// The token is fixed at construction; simply hand it back.
+nn::Request::MemoryDomainToken Buffer::getToken() const { return kToken; }
+
+// Copies this buffer's contents into `dst` by forwarding to aidl_hal::IBuffer::copyTo.
+nn::GeneralResult<void> Buffer::copyTo(const nn::SharedMemory& dst) const {
+    // The AIDL call needs the AIDL representation of the destination memory.
+    const auto aidlMemory = NN_TRY(convert(dst));
+
+    const auto astatus = kBuffer->copyTo(aidlMemory);
+    HANDLE_ASTATUS(astatus) << "IBuffer::copyTo failed";
+
+    return {};
+}
+
+// Fills this buffer from `src` with the given `dimensions` by forwarding to
+// aidl_hal::IBuffer::copyFrom.
+nn::GeneralResult<void> Buffer::copyFrom(const nn::SharedMemory& src,
+                                         const nn::Dimensions& dimensions) const {
+    // AIDL needs its own memory representation, and it carries dimensions as signed integers.
+    const auto aidlMemory = NN_TRY(convert(src));
+    const auto aidlDims = NN_TRY(toSigned(dimensions));
+
+    const auto astatus = kBuffer->copyFrom(aidlMemory, aidlDims);
+    HANDLE_ASTATUS(astatus) << "IBuffer::copyFrom failed";
+
+    return {};
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/src/Callbacks.cpp b/neuralnetworks/aidl/utils/src/Callbacks.cpp
new file mode 100644
index 0000000..8055665
--- /dev/null
+++ b/neuralnetworks/aidl/utils/src/Callbacks.cpp
@@ -0,0 +1,61 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Callbacks.h"
+
+#include "Conversions.h"
+#include "PreparedModel.h"
+#include "ProtectCallback.h"
+#include "Utils.h"
+
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+
+#include <utility>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+namespace {
+
+// Maps the driver's (status, preparedModel) notification to a canonical result: the HAL error
+// when `status` reports a failure, otherwise the model wrapped by PreparedModel::create.
+nn::GeneralResult<nn::SharedPreparedModel> prepareModelCallback(
+        ErrorStatus status, const std::shared_ptr<IPreparedModel>& preparedModel) {
+    HANDLE_HAL_STATUS(status) << "model preparation failed with " << toString(status);
+    auto wrappedModel = NN_TRY(PreparedModel::create(preparedModel));
+    return wrappedModel;
+}
+
+}  // namespace
+
+// Binder entry point invoked by the driver: stores the converted outcome for get() and always
+// acknowledges the call with an OK status.
+ndk::ScopedAStatus PreparedModelCallback::notify(
+        ErrorStatus status, const std::shared_ptr<IPreparedModel>& preparedModel) {
+    auto outcome = prepareModelCallback(status, preparedModel);
+    mData.put(std::move(outcome));
+    return ndk::ScopedAStatus::ok();
+}
+
+// Stores a DEAD_OBJECT error in place of a driver-provided result.
+void PreparedModelCallback::notifyAsDeadObject() {
+    mData.put(NN_ERROR(nn::ErrorStatus::DEAD_OBJECT) << "Dead object");
+}
+
+// Returns whatever result was stored, via mData.take().
+PreparedModelCallback::Data PreparedModelCallback::get() {
+    return mData.take();
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/src/Conversions.cpp b/neuralnetworks/aidl/utils/src/Conversions.cpp
index db3504b..5d9c55b 100644
--- a/neuralnetworks/aidl/utils/src/Conversions.cpp
+++ b/neuralnetworks/aidl/utils/src/Conversions.cpp
@@ -18,6 +18,8 @@
#include <aidl/android/hardware/common/NativeHandle.h>
#include <android-base/logging.h>
+#include <android-base/unique_fd.h>
+#include <android/binder_auto_utils.h>
#include <android/hardware_buffer.h>
#include <cutils/native_handle.h>
#include <nnapi/OperandTypes.h>
@@ -42,14 +44,17 @@
#define VERIFY_NON_NEGATIVE(value) \
while (UNLIKELY(value < 0)) return NN_ERROR()
-namespace {
+#define VERIFY_LE_INT32_MAX(value) \
+ while (UNLIKELY(value > std::numeric_limits<int32_t>::max())) return NN_ERROR()
+namespace {
template <typename Type>
constexpr std::underlying_type_t<Type> underlyingType(Type value) {
return static_cast<std::underlying_type_t<Type>>(value);
}
constexpr auto kVersion = android::nn::Version::ANDROID_S;
+constexpr int64_t kNoTiming = -1;
} // namespace
@@ -134,13 +139,8 @@
std::vector<base::unique_fd> fds;
fds.reserve(aidlNativeHandle.fds.size());
for (const auto& fd : aidlNativeHandle.fds) {
- const int dupFd = dup(fd.get());
- if (dupFd == -1) {
- // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return
- // here?
- return NN_ERROR() << "Failed to dup the fd";
- }
- fds.emplace_back(dupFd);
+ auto duplicatedFd = NN_TRY(dupFd(fd.get()));
+ fds.emplace_back(duplicatedFd.release());
}
return Handle{.fds = std::move(fds), .ints = aidlNativeHandle.ints};
@@ -157,16 +157,12 @@
using UniqueNativeHandle = std::unique_ptr<native_handle_t, NativeHandleDeleter>;
-static nn::GeneralResult<UniqueNativeHandle> nativeHandleFromAidlHandle(
- const NativeHandle& handle) {
+static GeneralResult<UniqueNativeHandle> nativeHandleFromAidlHandle(const NativeHandle& handle) {
std::vector<base::unique_fd> fds;
fds.reserve(handle.fds.size());
for (const auto& fd : handle.fds) {
- const int dupFd = dup(fd.get());
- if (dupFd == -1) {
- return NN_ERROR() << "Failed to dup the fd";
- }
- fds.emplace_back(dupFd);
+ auto duplicatedFd = NN_TRY(dupFd(fd.get()));
+ fds.emplace_back(duplicatedFd.release());
}
constexpr size_t kIntMax = std::numeric_limits<int>::max();
@@ -382,14 +378,14 @@
GeneralResult<SharedMemory> unvalidatedConvert(const aidl_hal::Memory& memory) {
VERIFY_NON_NEGATIVE(memory.size) << "Memory size must not be negative";
- if (memory.size > std::numeric_limits<uint32_t>::max()) {
+ if (memory.size > std::numeric_limits<size_t>::max()) {
return NN_ERROR() << "Memory: size must be <= std::numeric_limits<size_t>::max()";
}
if (memory.name != "hardware_buffer_blob") {
return std::make_shared<const Memory>(Memory{
.handle = NN_TRY(unvalidatedConvertHelper(memory.handle)),
- .size = static_cast<uint32_t>(memory.size),
+ .size = static_cast<size_t>(memory.size),
.name = memory.name,
});
}
@@ -434,11 +430,28 @@
return std::make_shared<const Memory>(Memory{
.handle = HardwareBufferHandle(hardwareBuffer, /*takeOwnership=*/true),
- .size = static_cast<uint32_t>(memory.size),
+ .size = static_cast<size_t>(memory.size),
.name = memory.name,
});
}
+// Converts AIDL execution timing to canonical form. Each field must be either kNoTiming (-1) or
+// non-negative; any value below -1 is rejected.
+GeneralResult<Timing> unvalidatedConvert(const aidl_hal::Timing& timing) {
+    if (timing.timeInDriver < -1) {
+        return NN_ERROR() << "Timing: timeInDriver must not be less than -1";
+    }
+    if (timing.timeOnDevice < -1) {
+        return NN_ERROR() << "Timing: timeOnDevice must not be less than -1";
+    }
+    // kNoTiming maps to an empty OptionalDuration; any other value becomes a Duration.
+    const auto toOptionalDuration = [](int64_t halTiming) -> OptionalDuration {
+        OptionalDuration result;
+        if (halTiming != kNoTiming) {
+            result = nn::Duration(static_cast<uint64_t>(halTiming));
+        }
+        return result;
+    };
+    const OptionalDuration onDevice = toOptionalDuration(timing.timeOnDevice);
+    const OptionalDuration inDriver = toOptionalDuration(timing.timeInDriver);
+    return Timing{.timeOnDevice = onDevice, .timeInDriver = inDriver};
+}
+
GeneralResult<Model::OperandValues> unvalidatedConvert(const std::vector<uint8_t>& operandValues) {
return Model::OperandValues(operandValues.data(), operandValues.size());
}
@@ -515,6 +528,23 @@
return std::make_shared<const Handle>(NN_TRY(unvalidatedConvertHelper(aidlNativeHandle)));
}
+// The resulting SyncFence must own its descriptor, so the caller's fd is duplicated rather than
+// borrowed.
+GeneralResult<SyncFence> unvalidatedConvert(const ndk::ScopedFileDescriptor& syncFence) {
+    auto ownedFd = NN_TRY(dupFd(syncFence.get()));
+    return SyncFence::create(std::move(ownedFd));
+}
+
+// Validating conversions: each forwards to validatedConvert for its AIDL type.
+GeneralResult<Capabilities> convert(const aidl_hal::Capabilities& capabilities) {
+    return validatedConvert(capabilities);
+}
+
+GeneralResult<DeviceType> convert(const aidl_hal::DeviceType& deviceType) {
+    return validatedConvert(deviceType);
+}
+
+GeneralResult<ErrorStatus> convert(const aidl_hal::ErrorStatus& errorStatus) {
+    return validatedConvert(errorStatus);
+}
+
GeneralResult<ExecutionPreference> convert(
const aidl_hal::ExecutionPreference& executionPreference) {
return validatedConvert(executionPreference);
@@ -548,6 +578,18 @@
return validatedConvert(request);
}
+// Validating conversion of AIDL timing.
+GeneralResult<Timing> convert(const aidl_hal::Timing& timing) {
+    return validatedConvert(timing);
+}
+
+// Sync fences are converted directly via unvalidatedConvert; no validatedConvert step is used.
+GeneralResult<SyncFence> convert(const ndk::ScopedFileDescriptor& syncFence) {
+    return unvalidatedConvert(syncFence);
+}
+
+// Validating conversion of the driver's extension list.
+GeneralResult<std::vector<Extension>> convert(const std::vector<aidl_hal::Extension>& extensions) {
+    return validatedConvert(extensions);
+}
+
GeneralResult<std::vector<Operation>> convert(const std::vector<aidl_hal::Operation>& operations) {
return unvalidatedConvert(operations);
}
@@ -556,6 +598,11 @@
return validatedConvert(memories);
}
+// Validating conversion of a list of AIDL output shapes.
+GeneralResult<std::vector<OutputShape>> convert(
+        const std::vector<aidl_hal::OutputShape>& shapes) {
+    return validatedConvert(shapes);
+}
+
GeneralResult<std::vector<uint32_t>> toUnsigned(const std::vector<int32_t>& vec) {
if (!std::all_of(vec.begin(), vec.end(), [](int32_t v) { return v >= 0; })) {
return NN_ERROR() << "Negative value passed to conversion from signed to unsigned";
@@ -575,14 +622,21 @@
template <typename Type>
nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvertVec(
const std::vector<Type>& arguments) {
- std::vector<UnvalidatedConvertOutput<Type>> halObject(arguments.size());
- for (size_t i = 0; i < arguments.size(); ++i) {
- halObject[i] = NN_TRY(unvalidatedConvert(arguments[i]));
+ std::vector<UnvalidatedConvertOutput<Type>> halObject;
+ halObject.reserve(arguments.size());
+ for (const auto& argument : arguments) {
+ halObject.push_back(NN_TRY(unvalidatedConvert(argument)));
}
return halObject;
}
template <typename Type>
+nn::GeneralResult<std::vector<UnvalidatedConvertOutput<Type>>> unvalidatedConvert(
+ const std::vector<Type>& arguments) {
+ return unvalidatedConvertVec(arguments);
+}
+
+template <typename Type>
nn::GeneralResult<UnvalidatedConvertOutput<Type>> validatedConvert(const Type& canonical) {
const auto maybeVersion = nn::validate(canonical);
if (!maybeVersion.has_value()) {
@@ -609,29 +663,29 @@
common::NativeHandle aidlNativeHandle;
aidlNativeHandle.fds.reserve(handle.fds.size());
for (const auto& fd : handle.fds) {
- const int dupFd = dup(fd.get());
- if (dupFd == -1) {
- // TODO(b/120417090): is ANEURALNETWORKS_UNEXPECTED_NULL the correct error to return
- // here?
- return NN_ERROR() << "Failed to dup the fd";
- }
- aidlNativeHandle.fds.emplace_back(dupFd);
+ auto duplicatedFd = NN_TRY(nn::dupFd(fd.get()));
+ aidlNativeHandle.fds.emplace_back(duplicatedFd.release());
}
aidlNativeHandle.ints = handle.ints;
return aidlNativeHandle;
}
+// Merges several callables into a single overload set so that std::visit can dispatch on the
+// active alternative of a variant.
+template <class... Visitors>
+struct overloaded : Visitors... {
+    using Visitors::operator()...;
+};
+// Deduction guide: `overloaded{lambdaA, lambdaB}` deduces its base classes from the lambdas.
+template <class... Visitors>
+overloaded(Visitors...) -> overloaded<Visitors...>;
+
static nn::GeneralResult<common::NativeHandle> aidlHandleFromNativeHandle(
const native_handle_t& handle) {
common::NativeHandle aidlNativeHandle;
aidlNativeHandle.fds.reserve(handle.numFds);
for (int i = 0; i < handle.numFds; ++i) {
- const int dupFd = dup(handle.data[i]);
- if (dupFd == -1) {
- return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) << "Failed to dup the fd";
- }
- aidlNativeHandle.fds.emplace_back(dupFd);
+ auto duplicatedFd = NN_TRY(nn::dupFd(handle.data[i]));
+ aidlNativeHandle.fds.emplace_back(duplicatedFd.release());
}
aidlNativeHandle.ints = std::vector<int>(&handle.data[handle.numFds],
@@ -642,6 +696,30 @@
} // namespace
+// Copies the fixed-size cache token into a byte vector.
+nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(const nn::CacheToken& cacheToken) {
+    std::vector<uint8_t> tokenBytes(cacheToken.begin(), cacheToken.end());
+    return tokenBytes;
+}
+
+// Converts the buffer descriptor; only the dimensions need a signedness conversion.
+nn::GeneralResult<BufferDesc> unvalidatedConvert(const nn::BufferDesc& bufferDesc) {
+    auto signedDimensions = NN_TRY(toSigned(bufferDesc.dimensions));
+    return BufferDesc{.dimensions = std::move(signedDimensions)};
+}
+
+// AIDL stores the role indices as int32_t, so each index is range-checked before narrowing.
+nn::GeneralResult<BufferRole> unvalidatedConvert(const nn::BufferRole& bufferRole) {
+    VERIFY_LE_INT32_MAX(bufferRole.modelIndex)
+            << "BufferRole: modelIndex must be <= std::numeric_limits<int32_t>::max()";
+    VERIFY_LE_INT32_MAX(bufferRole.ioIndex)
+            << "BufferRole: ioIndex must be <= std::numeric_limits<int32_t>::max()";
+    const auto modelIndex = static_cast<int32_t>(bufferRole.modelIndex);
+    const auto ioIndex = static_cast<int32_t>(bufferRole.ioIndex);
+    return BufferRole{
+            .modelIndex = modelIndex,
+            .ioIndex = ioIndex,
+            .frequency = bufferRole.frequency,
+    };
+}
+
+// AIDL represents MeasureTiming as a plain boolean: true only for MeasureTiming::YES.
+nn::GeneralResult<bool> unvalidatedConvert(const nn::MeasureTiming& measureTiming) {
+    const bool shouldMeasure = (measureTiming == nn::MeasureTiming::YES);
+    return shouldMeasure;
+}
+
nn::GeneralResult<common::NativeHandle> unvalidatedConvert(const nn::SharedHandle& sharedHandle) {
CHECK(sharedHandle != nullptr);
return unvalidatedConvert(*sharedHandle);
@@ -707,6 +785,230 @@
.isSufficient = outputShape.isSufficient};
}
+// The three enum conversions below are direct casts; they rely on the AIDL enumerators having the
+// same numeric values as the canonical ones.
+nn::GeneralResult<ExecutionPreference> unvalidatedConvert(
+        const nn::ExecutionPreference& executionPreference) {
+    return static_cast<ExecutionPreference>(executionPreference);
+}
+
+nn::GeneralResult<OperandType> unvalidatedConvert(const nn::OperandType& operandType) {
+    return static_cast<OperandType>(operandType);
+}
+
+nn::GeneralResult<OperandLifeTime> unvalidatedConvert(
+        const nn::Operand::LifeTime& operandLifeTime) {
+    return static_cast<OperandLifeTime>(operandLifeTime);
+}
+
+// poolIndex is narrowed to int32_t and therefore range-checked first; offset and length are cast
+// to int64_t.
+nn::GeneralResult<DataLocation> unvalidatedConvert(const nn::DataLocation& location) {
+    VERIFY_LE_INT32_MAX(location.poolIndex)
+            << "DataLocation: pool index must be <= std::numeric_limits<int32_t>::max()";
+    const auto poolIndex = static_cast<int32_t>(location.poolIndex);
+    const auto offset = static_cast<int64_t>(location.offset);
+    const auto length = static_cast<int64_t>(location.length);
+    return DataLocation{.poolIndex = poolIndex, .offset = offset, .length = length};
+}
+
+// Converts canonical operand extra parameters into the optional AIDL OperandExtraParams union by
+// visiting the active alternative of `extraParams`:
+//   NoParams                  -> empty optional
+//   SymmPerChannelQuantParams -> channelQuant (channelDim range-checked, then narrowed to int32_t)
+//   ExtensionParams           -> extension (passed through as-is)
+nn::GeneralResult<std::optional<OperandExtraParams>> unvalidatedConvert(
+ const nn::Operand::ExtraParams& extraParams) {
+ return std::visit(
+ overloaded{
+ // No extra parameters: the AIDL field is left empty.
+ [](const nn::Operand::NoParams&)
+ -> nn::GeneralResult<std::optional<OperandExtraParams>> {
+ return std::nullopt;
+ },
+ // Per-channel quantization: channelDim must fit in int32_t before it is narrowed.
+ [](const nn::Operand::SymmPerChannelQuantParams& symmPerChannelQuantParams)
+ -> nn::GeneralResult<std::optional<OperandExtraParams>> {
+ if (symmPerChannelQuantParams.channelDim >
+ std::numeric_limits<int32_t>::max()) {
+ // Using explicit type conversion because std::optional in successful
+ // result confuses the compiler.
+ return (NN_ERROR() << "symmPerChannelQuantParams.channelDim must be <= "
+ "std::numeric_limits<int32_t>::max(), received: "
+ << symmPerChannelQuantParams.channelDim)
+ .
+ operator nn::GeneralResult<std::optional<OperandExtraParams>>();
+ }
+ return OperandExtraParams::make<OperandExtraParams::Tag::channelQuant>(
+ SymmPerChannelQuantParams{
+ .scales = symmPerChannelQuantParams.scales,
+ .channelDim = static_cast<int32_t>(
+ symmPerChannelQuantParams.channelDim),
+ });
+ },
+ // Extension parameters are forwarded unchanged into the extension alternative.
+ [](const nn::Operand::ExtensionParams& extensionParams)
+ -> nn::GeneralResult<std::optional<OperandExtraParams>> {
+ return OperandExtraParams::make<OperandExtraParams::Tag::extension>(
+ extensionParams);
+ },
+ },
+ extraParams);
+}
+
+// Converts a canonical operand. Fields are converted in declaration order, so the first failing
+// field is the one reported.
+nn::GeneralResult<Operand> unvalidatedConvert(const nn::Operand& operand) {
+    auto type = NN_TRY(unvalidatedConvert(operand.type));
+    auto dimensions = NN_TRY(toSigned(operand.dimensions));
+    auto lifetime = NN_TRY(unvalidatedConvert(operand.lifetime));
+    auto location = NN_TRY(unvalidatedConvert(operand.location));
+    auto extraParams = NN_TRY(unvalidatedConvert(operand.extraParams));
+    return Operand{
+            .type = type,
+            .dimensions = std::move(dimensions),
+            .scale = operand.scale,
+            .zeroPoint = operand.zeroPoint,
+            .lifetime = lifetime,
+            .location = std::move(location),
+            .extraParams = std::move(extraParams),
+    };
+}
+
+// Direct cast; relies on matching AIDL and canonical enumerator values.
+nn::GeneralResult<OperationType> unvalidatedConvert(const nn::OperationType& operationType) {
+    return static_cast<OperationType>(operationType);
+}
+
+// Converts an operation: its type plus the signed input/output operand index lists.
+nn::GeneralResult<Operation> unvalidatedConvert(const nn::Operation& operation) {
+    auto type = NN_TRY(unvalidatedConvert(operation.type));
+    auto inputs = NN_TRY(toSigned(operation.inputs));
+    auto outputs = NN_TRY(toSigned(operation.outputs));
+    return Operation{
+            .type = type,
+            .inputs = std::move(inputs),
+            .outputs = std::move(outputs),
+    };
+}
+
+// Converts a subgraph: operands, operations, then the model input/output index lists.
+nn::GeneralResult<Subgraph> unvalidatedConvert(const nn::Model::Subgraph& subgraph) {
+    auto operands = NN_TRY(unvalidatedConvert(subgraph.operands));
+    auto operations = NN_TRY(unvalidatedConvert(subgraph.operations));
+    auto inputIndexes = NN_TRY(toSigned(subgraph.inputIndexes));
+    auto outputIndexes = NN_TRY(toSigned(subgraph.outputIndexes));
+    return Subgraph{
+            .operands = std::move(operands),
+            .operations = std::move(operations),
+            .inputIndexes = std::move(inputIndexes),
+            .outputIndexes = std::move(outputIndexes),
+    };
+}
+
+// Copies the model's constant operand bytes into a plain byte vector.
+nn::GeneralResult<std::vector<uint8_t>> unvalidatedConvert(
+        const nn::Model::OperandValues& operandValues) {
+    const auto* const first = operandValues.data();
+    return std::vector<uint8_t>(first, first + operandValues.size());
+}
+
+// Field-for-field copy of an extension name/prefix mapping.
+nn::GeneralResult<ExtensionNameAndPrefix> unvalidatedConvert(
+        const nn::Model::ExtensionNameAndPrefix& extensionNameToPrefix) {
+    ExtensionNameAndPrefix aidlMapping{
+            .name = extensionNameToPrefix.name,
+            .prefix = extensionNameToPrefix.prefix,
+    };
+    return aidlMapping;
+}
+
+// Converts a whole model. The parts are converted in the same order as the AIDL fields are
+// initialized, so the first failure is the one reported.
+nn::GeneralResult<Model> unvalidatedConvert(const nn::Model& model) {
+    auto mainSubgraph = NN_TRY(unvalidatedConvert(model.main));
+    auto referencedSubgraphs = NN_TRY(unvalidatedConvert(model.referenced));
+    auto operandValues = NN_TRY(unvalidatedConvert(model.operandValues));
+    auto pools = NN_TRY(unvalidatedConvert(model.pools));
+    auto extensionNameToPrefix = NN_TRY(unvalidatedConvert(model.extensionNameToPrefix));
+    return Model{
+            .main = std::move(mainSubgraph),
+            .referenced = std::move(referencedSubgraphs),
+            .operandValues = std::move(operandValues),
+            .pools = std::move(pools),
+            .relaxComputationFloat32toFloat16 = model.relaxComputationFloat32toFloat16,
+            .extensionNameToPrefix = std::move(extensionNameToPrefix),
+    };
+}
+
+// Direct cast; relies on matching AIDL and canonical enumerator values.
+nn::GeneralResult<Priority> unvalidatedConvert(const nn::Priority& priority) {
+    return static_cast<Priority>(priority);
+}
+
+// Converts a request's input arguments, output arguments and memory pools, in that order.
+nn::GeneralResult<Request> unvalidatedConvert(const nn::Request& request) {
+    auto inputs = NN_TRY(unvalidatedConvert(request.inputs));
+    auto outputs = NN_TRY(unvalidatedConvert(request.outputs));
+    auto pools = NN_TRY(unvalidatedConvert(request.pools));
+    return Request{
+            .inputs = std::move(inputs),
+            .outputs = std::move(outputs),
+            .pools = std::move(pools),
+    };
+}
+
+// Converts one request argument. POINTER lifetimes are rejected with INVALID_ARGUMENT because
+// pointer-based memory cannot be expressed in AIDL; NO_VALUE sets hasNoValue.
+nn::GeneralResult<RequestArgument> unvalidatedConvert(
+        const nn::Request::Argument& requestArgument) {
+    using LifeTime = nn::Request::Argument::LifeTime;
+    if (requestArgument.lifetime == LifeTime::POINTER) {
+        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
+               << "Request cannot be unvalidatedConverted because it contains pointer-based memory";
+    }
+    const bool hasNoValue = (requestArgument.lifetime == LifeTime::NO_VALUE);
+    auto location = NN_TRY(unvalidatedConvert(requestArgument.location));
+    auto dimensions = NN_TRY(toSigned(requestArgument.dimensions));
+    return RequestArgument{
+            .hasNoValue = hasNoValue,
+            .location = std::move(location),
+            .dimensions = std::move(dimensions),
+    };
+}
+
+// Converts one request memory pool by visiting its active alternative:
+//   SharedMemory      -> RequestMemoryPool::pool (memory converted via unvalidatedConvert)
+//   MemoryDomainToken -> RequestMemoryPool::token (the token's underlying integer value)
+//   SharedBuffer      -> GENERAL_FAILURE error; an IBuffer cannot be turned into a memory pool
+nn::GeneralResult<RequestMemoryPool> unvalidatedConvert(const nn::Request::MemoryPool& memoryPool) {
+ return std::visit(
+ overloaded{
+ [](const nn::SharedMemory& memory) -> nn::GeneralResult<RequestMemoryPool> {
+ return RequestMemoryPool::make<RequestMemoryPool::Tag::pool>(
+ NN_TRY(unvalidatedConvert(memory)));
+ },
+ [](const nn::Request::MemoryDomainToken& token)
+ -> nn::GeneralResult<RequestMemoryPool> {
+ return RequestMemoryPool::make<RequestMemoryPool::Tag::token>(
+ underlyingType(token));
+ },
+ // The explicit conversion-operator call gives this lambda the same return type as the
+ // other alternatives, which std::visit requires.
+ [](const nn::SharedBuffer& /*buffer*/) {
+ return (NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE)
+ << "Unable to make memory pool from IBuffer")
+ .
+ operator nn::GeneralResult<RequestMemoryPool>();
+ },
+ },
+ memoryPool);
+}
+
+// Converts canonical timing to AIDL; each optional duration becomes an int64_t (see below).
+nn::GeneralResult<Timing> unvalidatedConvert(const nn::Timing& timing) {
+    const int64_t onDevice = NN_TRY(unvalidatedConvert(timing.timeOnDevice));
+    const int64_t inDriver = NN_TRY(unvalidatedConvert(timing.timeInDriver));
+    return Timing{.timeOnDevice = onDevice, .timeInDriver = inDriver};
+}
+
+// Narrows a duration count to int64_t, saturating at INT64_MAX instead of failing.
+nn::GeneralResult<int64_t> unvalidatedConvert(const nn::Duration& duration) {
+    constexpr auto kMaxInt64 = std::numeric_limits<int64_t>::max();
+    const uint64_t nanoseconds = duration.count();
+    return nanoseconds > static_cast<uint64_t>(kMaxInt64) ? kMaxInt64
+                                                          : static_cast<int64_t>(nanoseconds);
+}
+
+// An absent duration is encoded as kNoTiming (-1).
+nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalDuration& optionalDuration) {
+    if (optionalDuration.has_value()) {
+        return unvalidatedConvert(*optionalDuration);
+    }
+    return kNoTiming;
+}
+
+// An absent time point is encoded as kNoTiming (-1); otherwise its time since epoch is used.
+nn::GeneralResult<int64_t> unvalidatedConvert(const nn::OptionalTimePoint& optionalTimePoint) {
+    if (optionalTimePoint.has_value()) {
+        return unvalidatedConvert(optionalTimePoint->time_since_epoch());
+    }
+    return kNoTiming;
+}
+
+// Duplicates the fence's fd so the returned ScopedFileDescriptor owns an independent copy.
+nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvert(const nn::SyncFence& syncFence) {
+    auto ownedFd = NN_TRY(nn::dupFd(syncFence.getFd()));
+    return ndk::ScopedFileDescriptor(ownedFd.release());
+}
+
+// Converts a single model/data cache handle into the file descriptor AIDL expects.
+// A cache handle must be non-null, hold exactly one fd and carry no ints. The fd is duplicated, so
+// the returned ScopedFileDescriptor owns its own copy.
+// Each check must *return* its error: falling through on an fd-less handle would call
+// `fds.front()` on an empty vector.
+nn::GeneralResult<ndk::ScopedFileDescriptor> unvalidatedConvertCache(
+        const nn::SharedHandle& handle) {
+    if (handle == nullptr) {
+        return NN_ERROR() << "Cache handle must not be null";
+    }
+    if (!handle->ints.empty()) {
+        return NN_ERROR() << "Cache handle must not contain ints";
+    }
+    if (handle->fds.size() != 1) {
+        return NN_ERROR() << "Cache handle must contain exactly one fd but contains "
+                          << handle->fds.size();
+    }
+    auto duplicatedFd = NN_TRY(nn::dupFd(handle->fds.front().get()));
+    return ndk::ScopedFileDescriptor(duplicatedFd.release());
+}
+
+// The cache token is copied without a validation step.
+nn::GeneralResult<std::vector<uint8_t>> convert(const nn::CacheToken& token) {
+    return unvalidatedConvert(token);
+}
+
+// Validating conversion of a buffer descriptor.
+nn::GeneralResult<BufferDesc> convert(const nn::BufferDesc& desc) {
+    return validatedConvert(desc);
+}
+
+// Validating conversion of the timing-measurement flag.
+nn::GeneralResult<bool> convert(const nn::MeasureTiming& measure) {
+    return validatedConvert(measure);
+}
+
nn::GeneralResult<Memory> convert(const nn::SharedMemory& memory) {
return validatedConvert(memory);
}
@@ -715,11 +1017,62 @@
return validatedConvert(errorStatus);
}
+// Validating conversions from canonical types to their AIDL counterparts; each forwards to
+// validatedConvert.
+nn::GeneralResult<ExecutionPreference> convert(const nn::ExecutionPreference& executionPreference) {
+    return validatedConvert(executionPreference);
+}
+
+nn::GeneralResult<Model> convert(const nn::Model& model) {
+    return validatedConvert(model);
+}
+
+nn::GeneralResult<Priority> convert(const nn::Priority& priority) {
+    return validatedConvert(priority);
+}
+
+nn::GeneralResult<Request> convert(const nn::Request& request) {
+    return validatedConvert(request);
+}
+
+nn::GeneralResult<Timing> convert(const nn::Timing& timing) {
+    return validatedConvert(timing);
+}
+
+// An absent duration converts to kNoTiming.
+nn::GeneralResult<int64_t> convert(const nn::OptionalDuration& optionalDuration) {
+    return validatedConvert(optionalDuration);
+}
+
+// An absent time point (e.g. no deadline) converts to kNoTiming.
+nn::GeneralResult<int64_t> convert(const nn::OptionalTimePoint& optionalTimePoint) {
+    return validatedConvert(optionalTimePoint);
+}
+
+nn::GeneralResult<std::vector<BufferRole>> convert(const std::vector<nn::BufferRole>& bufferRoles) {
+    return validatedConvert(bufferRoles);
+}
+
nn::GeneralResult<std::vector<OutputShape>> convert(
const std::vector<nn::OutputShape>& outputShapes) {
return validatedConvert(outputShapes);
}
+// Converts model/data cache handles to AIDL file descriptors. The whole list is validated (and
+// its version checked) before any fd is duplicated.
+nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
+        const std::vector<nn::SharedHandle>& cacheHandles) {
+    const auto version = NN_TRY(hal::utils::makeGeneralFailure(nn::validate(cacheHandles)));
+    if (version > kVersion) {
+        return NN_ERROR() << "Insufficient version: " << version << " vs required " << kVersion;
+    }
+    std::vector<ndk::ScopedFileDescriptor> cacheFds;
+    cacheFds.reserve(cacheHandles.size());
+    for (const nn::SharedHandle& handle : cacheHandles) {
+        auto fd = NN_TRY(unvalidatedConvertCache(handle));
+        cacheFds.push_back(std::move(fd));
+    }
+    return cacheFds;
+}
+
+// Sync fences are converted element-wise through unvalidatedConvert, without a validation pass.
+nn::GeneralResult<std::vector<ndk::ScopedFileDescriptor>> convert(
+        const std::vector<nn::SyncFence>& fences) {
+    return unvalidatedConvert(fences);
+}
+
nn::GeneralResult<std::vector<int32_t>> toSigned(const std::vector<uint32_t>& vec) {
if (!std::all_of(vec.begin(), vec.end(),
[](uint32_t v) { return v <= std::numeric_limits<int32_t>::max(); })) {
diff --git a/neuralnetworks/aidl/utils/src/Device.cpp b/neuralnetworks/aidl/utils/src/Device.cpp
new file mode 100644
index 0000000..02ca861
--- /dev/null
+++ b/neuralnetworks/aidl/utils/src/Device.cpp
@@ -0,0 +1,294 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Device.h"
+
+#include "Buffer.h"
+#include "Callbacks.h"
+#include "Conversions.h"
+#include "PreparedModel.h"
+#include "ProtectCallback.h"
+#include "Utils.h"
+
+#include <aidl/android/hardware/neuralnetworks/IDevice.h>
+#include <android/binder_auto_utils.h>
+#include <android/binder_interface_utils.h>
+#include <nnapi/IBuffer.h>
+#include <nnapi/IDevice.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/OperandTypes.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/CommonUtils.h>
+
+#include <any>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <vector>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+namespace {
+
+// Extracts the AIDL IPreparedModel objects backing canonical prepared models.
+// Fails with INVALID_ARGUMENT if an entry is null or is not backed by an AIDL prepared model
+// (i.e. its underlying resource is not a std::shared_ptr<aidl_hal::IPreparedModel>).
+nn::GeneralResult<std::vector<std::shared_ptr<IPreparedModel>>> convert(
+        const std::vector<nn::SharedPreparedModel>& preparedModels) {
+    std::vector<std::shared_ptr<IPreparedModel>> aidlPreparedModels;
+    aidlPreparedModels.reserve(preparedModels.size());
+    for (const auto& preparedModel : preparedModels) {
+        // Guard before dereferencing: a null entry would otherwise crash here.
+        if (preparedModel == nullptr) {
+            return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
+                   << "Unable to convert a null nn::IPreparedModel to aidl_hal::IPreparedModel";
+        }
+        std::any underlyingResource = preparedModel->getUnderlyingResource();
+        const auto* aidlPreparedModel =
+                std::any_cast<std::shared_ptr<aidl_hal::IPreparedModel>>(&underlyingResource);
+        if (aidlPreparedModel == nullptr) {
+            return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
+                   << "Unable to convert from nn::IPreparedModel to aidl_hal::IPreparedModel";
+        }
+        aidlPreparedModels.push_back(*aidlPreparedModel);
+    }
+    return aidlPreparedModels;
+}
+
+nn::GeneralResult<nn::Capabilities> getCapabilitiesFrom(IDevice* device) {
+ CHECK(device != nullptr);
+ Capabilities capabilities;
+ const auto ret = device->getCapabilities(&capabilities);
+ HANDLE_ASTATUS(ret) << "getCapabilities failed";
+ return nn::convert(capabilities);
+}
+
+nn::GeneralResult<std::string> getVersionStringFrom(aidl_hal::IDevice* device) {
+ CHECK(device != nullptr);
+ std::string version;
+ const auto ret = device->getVersionString(&version);
+ HANDLE_ASTATUS(ret) << "getVersionString failed";
+ return version;
+}
+
+nn::GeneralResult<nn::DeviceType> getDeviceTypeFrom(aidl_hal::IDevice* device) {
+ CHECK(device != nullptr);
+ DeviceType deviceType;
+ const auto ret = device->getType(&deviceType);
+ HANDLE_ASTATUS(ret) << "getDeviceType failed";
+ return nn::convert(deviceType);
+}
+
+nn::GeneralResult<std::vector<nn::Extension>> getSupportedExtensionsFrom(
+ aidl_hal::IDevice* device) {
+ CHECK(device != nullptr);
+ std::vector<Extension> supportedExtensions;
+ const auto ret = device->getSupportedExtensions(&supportedExtensions);
+ HANDLE_ASTATUS(ret) << "getExtensions failed";
+ return nn::convert(supportedExtensions);
+}
+
+// Queries how many model-cache and data-cache files the driver needs.
+// Rejects negative counts and counts above nn::kMaxNumberOfCacheFiles.
+// Returns the pair in the order the canonical nn::IDevice API uses: {numModelCache, numDataCache}.
+// The previous order was {numDataCache, numModelCache}, which silently swapped the two counts.
+nn::GeneralResult<std::pair<uint32_t, uint32_t>> getNumberOfCacheFilesNeededFrom(
+        aidl_hal::IDevice* device) {
+    CHECK(device != nullptr);
+    NumberOfCacheFiles numberOfCacheFiles;
+    const auto ret = device->getNumberOfCacheFilesNeeded(&numberOfCacheFiles);
+    HANDLE_ASTATUS(ret) << "getNumberOfCacheFilesNeeded failed";
+
+    if (numberOfCacheFiles.numDataCache < 0 || numberOfCacheFiles.numModelCache < 0) {
+        return NN_ERROR() << "Driver reported negative number of cache files needed";
+    }
+    // Safe to cast: both counts were verified non-negative above.
+    const auto numModelCache = static_cast<uint32_t>(numberOfCacheFiles.numModelCache);
+    const auto numDataCache = static_cast<uint32_t>(numberOfCacheFiles.numDataCache);
+    if (numModelCache > nn::kMaxNumberOfCacheFiles) {
+        return NN_ERROR() << "getNumberOfCacheFilesNeeded returned numModelCache files greater "
+                             "than allowed max ("
+                          << numModelCache << " vs " << nn::kMaxNumberOfCacheFiles << ")";
+    }
+    if (numDataCache > nn::kMaxNumberOfCacheFiles) {
+        return NN_ERROR() << "getNumberOfCacheFilesNeeded returned numDataCache files greater "
+                             "than allowed max ("
+                          << numDataCache << " vs " << nn::kMaxNumberOfCacheFiles << ")";
+    }
+    return std::make_pair(numModelCache, numDataCache);
+}
+
+} // namespace
+
+// Creates a Device adapter around an AIDL IDevice.
+// Rejects an empty `name` or null `device` with INVALID_ARGUMENT. Otherwise queries the driver's
+// static properties once and caches them in the new object.
+nn::GeneralResult<std::shared_ptr<const Device>> Device::create(
+        std::string name, std::shared_ptr<aidl_hal::IDevice> device) {
+    if (name.empty()) {
+        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
+               << "aidl_hal::utils::Device::create must have non-empty name";
+    }
+    if (device == nullptr) {
+        return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
+               << "aidl_hal::utils::Device::create must have non-null device";
+    }
+
+    // The driver is queried in this fixed order; the first failing query aborts creation.
+    aidl_hal::IDevice* const rawDevice = device.get();
+    auto versionString = NN_TRY(getVersionStringFrom(rawDevice));
+    const auto deviceType = NN_TRY(getDeviceTypeFrom(rawDevice));
+    auto extensions = NN_TRY(getSupportedExtensionsFrom(rawDevice));
+    auto capabilities = NN_TRY(getCapabilitiesFrom(rawDevice));
+    const auto numberOfCacheFilesNeeded = NN_TRY(getNumberOfCacheFilesNeededFrom(rawDevice));
+
+    auto deathHandler = NN_TRY(DeathHandler::create(device));
+    return std::make_shared<const Device>(
+            PrivateConstructorTag{}, std::move(name), std::move(versionString), deviceType,
+            std::move(extensions), std::move(capabilities), numberOfCacheFilesNeeded,
+            std::move(device), std::move(deathHandler));
+}
+
+// Stores the already-queried properties; performs no driver calls itself.
+Device::Device(PrivateConstructorTag /*tag*/, std::string name, std::string versionString,
+               nn::DeviceType deviceType, std::vector<nn::Extension> extensions,
+               nn::Capabilities capabilities,
+               std::pair<uint32_t, uint32_t> numberOfCacheFilesNeeded,
+               std::shared_ptr<aidl_hal::IDevice> device, DeathHandler deathHandler)
+    : kName(std::move(name)),
+      kVersionString(std::move(versionString)),
+      kDeviceType(deviceType),
+      kExtensions(std::move(extensions)),
+      kCapabilities(std::move(capabilities)),
+      kNumberOfCacheFilesNeeded(numberOfCacheFilesNeeded),
+      kDevice(std::move(device)),
+      kDeathHandler(std::move(deathHandler)) {}
+
+// Accessors for the properties cached by Device::create.
+const std::string& Device::getName() const {
+    return kName;
+}
+
+const std::string& Device::getVersionString() const {
+    return kVersionString;
+}
+
+// This adapter always reports the Android S feature level.
+nn::Version Device::getFeatureLevel() const {
+    return nn::Version::ANDROID_S;
+}
+
+nn::DeviceType Device::getType() const {
+    return kDeviceType;
+}
+
+// AIDL devices are always reported as not updatable.
+bool Device::isUpdatable() const {
+    return false;
+}
+
+const std::vector<nn::Extension>& Device::getSupportedExtensions() const {
+    return kExtensions;
+}
+
+const nn::Capabilities& Device::getCapabilities() const {
+    return kCapabilities;
+}
+
+std::pair<uint32_t, uint32_t> Device::getNumberOfCacheFilesNeeded() const {
+    return kNumberOfCacheFilesNeeded;
+}
+
+// Pings the driver's binder; a failed ping is reported as an error.
+nn::GeneralResult<void> Device::wait() const {
+    const auto pingResult = AIBinder_ping(kDevice->asBinder().get());
+    const auto aidlStatus = ndk::ScopedAStatus::fromStatus(pingResult);
+    HANDLE_ASTATUS(aidlStatus) << "ping failed";
+    return {};
+}
+
+// Asks the driver which operations of `model` it supports.
+nn::GeneralResult<std::vector<bool>> Device::getSupportedOperations(const nn::Model& model) const {
+    // Make the model IPC-ready (pointer-based operand data flushed to shared memory).
+    // NOTE(review): `maybeModelInShared` presumably serves as backing storage when a copy is needed.
+    std::optional<nn::Model> maybeModelInShared;
+    const nn::Model& modelInShared =
+            NN_TRY(hal::utils::flushDataFromPointerToShared(&model, &maybeModelInShared));
+    const auto aidlModel = NN_TRY(convert(modelInShared));
+
+    std::vector<bool> supported;
+    const auto aidlStatus = kDevice->getSupportedOperations(aidlModel, &supported);
+    HANDLE_ASTATUS(aidlStatus) << "getSupportedOperations failed";
+    return supported;
+}
+
+// Compiles `model` on the driver and blocks until the PreparedModelCallback delivers a result.
+nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModel(
+        const nn::Model& model, nn::ExecutionPreference preference, nn::Priority priority,
+        nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
+        const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
+    // Make the model IPC-ready (pointer-based operand data flushed to shared memory).
+    std::optional<nn::Model> maybeModelInShared;
+    const nn::Model& modelInShared =
+            NN_TRY(hal::utils::flushDataFromPointerToShared(&model, &maybeModelInShared));
+
+    // Convert every argument before contacting the driver.
+    const auto aidlModel = NN_TRY(convert(modelInShared));
+    const auto aidlPreference = NN_TRY(convert(preference));
+    const auto aidlPriority = NN_TRY(convert(priority));
+    const auto aidlDeadline = NN_TRY(convert(deadline));
+    const auto aidlModelCache = NN_TRY(convert(modelCache));
+    const auto aidlDataCache = NN_TRY(convert(dataCache));
+    const auto aidlToken = NN_TRY(convert(token));
+
+    const auto callback = ndk::SharedRefBase::make<PreparedModelCallback>();
+    // `protection` must stay alive until callback->get() returns.
+    const auto protection = kDeathHandler.protectCallback(callback.get());
+
+    const auto aidlStatus =
+            kDevice->prepareModel(aidlModel, aidlPreference, aidlPriority, aidlDeadline,
+                                  aidlModelCache, aidlDataCache, aidlToken, callback);
+    HANDLE_ASTATUS(aidlStatus) << "prepareModel failed";
+
+    return callback->get();
+}
+
+// Re-creates a prepared model from cache files and blocks until the callback delivers a result.
+nn::GeneralResult<nn::SharedPreparedModel> Device::prepareModelFromCache(
+        nn::OptionalTimePoint deadline, const std::vector<nn::SharedHandle>& modelCache,
+        const std::vector<nn::SharedHandle>& dataCache, const nn::CacheToken& token) const {
+    const auto aidlDeadline = NN_TRY(convert(deadline));
+    const auto aidlModelCache = NN_TRY(convert(modelCache));
+    const auto aidlDataCache = NN_TRY(convert(dataCache));
+    const auto aidlToken = NN_TRY(convert(token));
+
+    const auto callback = ndk::SharedRefBase::make<PreparedModelCallback>();
+    // `protection` must stay alive until callback->get() returns.
+    const auto protection = kDeathHandler.protectCallback(callback.get());
+
+    const auto aidlStatus = kDevice->prepareModelFromCache(aidlDeadline, aidlModelCache,
+                                                           aidlDataCache, aidlToken, callback);
+    HANDLE_ASTATUS(aidlStatus) << "prepareModelFromCache failed";
+
+    return callback->get();
+}
+
+// Asks the driver to allocate a device buffer usable by `preparedModels` in the given roles.
+// Wraps the driver's buffer and token in a Buffer; a negative token is rejected.
+nn::GeneralResult<nn::SharedBuffer> Device::allocate(
+        const nn::BufferDesc& desc, const std::vector<nn::SharedPreparedModel>& preparedModels,
+        const std::vector<nn::BufferRole>& inputRoles,
+        const std::vector<nn::BufferRole>& outputRoles) const {
+    const auto aidlDesc = NN_TRY(convert(desc));
+    const auto aidlPreparedModels = NN_TRY(convert(preparedModels));
+    const auto aidlInputRoles = NN_TRY(convert(inputRoles));
+    const auto aidlOutputRoles = NN_TRY(convert(outputRoles));
+
+    // The AIDL API takes each prepared model wrapped in an IPreparedModelParcel.
+    std::vector<IPreparedModelParcel> parcels;
+    parcels.reserve(aidlPreparedModels.size());
+    for (const auto& aidlPreparedModel : aidlPreparedModels) {
+        parcels.push_back(IPreparedModelParcel{.preparedModel = aidlPreparedModel});
+    }
+
+    DeviceBuffer deviceBuffer;
+    const auto aidlStatus = kDevice->allocate(aidlDesc, parcels, aidlInputRoles,
+                                              aidlOutputRoles, &deviceBuffer);
+    HANDLE_ASTATUS(aidlStatus) << "IDevice::allocate failed";
+
+    if (deviceBuffer.token < 0) {
+        return NN_ERROR() << "IDevice::allocate returned negative token";
+    }
+    const auto token = static_cast<nn::Request::MemoryDomainToken>(deviceBuffer.token);
+    return Buffer::create(deviceBuffer.buffer, token);
+}
+
+// Returns the death monitor owned by this device's DeathHandler.
+DeathMonitor* Device::getDeathMonitor() const {
+    return kDeathHandler.getDeathMonitor().get();
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/src/PreparedModel.cpp b/neuralnetworks/aidl/utils/src/PreparedModel.cpp
new file mode 100644
index 0000000..aee4d90
--- /dev/null
+++ b/neuralnetworks/aidl/utils/src/PreparedModel.cpp
@@ -0,0 +1,172 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "PreparedModel.h"
+
+#include "Callbacks.h"
+#include "Conversions.h"
+#include "ProtectCallback.h"
+#include "Utils.h"
+
+#include <android/binder_auto_utils.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/Result.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/1.0/Burst.h>
+#include <nnapi/hal/CommonUtils.h>
+#include <nnapi/hal/HandleError.h>
+
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+// See hardware/interfaces/neuralnetworks/utils/README.md for more information on AIDL interface
+// lifetimes across processes and for protecting asynchronous calls across AIDL.
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+namespace {
+
+// Converts the AIDL output shapes and timing of a synchronous execution into their canonical
+// counterparts, returning the first conversion error encountered.
+nn::GeneralResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> convertExecutionResults(
+ const std::vector<OutputShape>& outputShapes, const Timing& timing) {
+ // Convert in two separate statements. NN_TRY may return early, and the arguments of a single
+ // call (e.g. std::make_pair) are evaluated in an unspecified order. Inside one call it would
+ // be compiler-dependent which error is reported, and an early return would leave the full
+ // expression with an already-materialized argument temporary.
+ auto canonicalOutputShapes = NN_TRY(nn::convert(outputShapes));
+ const auto canonicalTiming = NN_TRY(nn::convert(timing));
+ return std::make_pair(std::move(canonicalOutputShapes), canonicalTiming);
+}
+
+nn::GeneralResult<std::pair<nn::Timing, nn::Timing>> convertFencedExecutionResults(
+ ErrorStatus status, const aidl_hal::Timing& timingLaunched,
+ const aidl_hal::Timing& timingFenced) {
+ HANDLE_HAL_STATUS(status) << "fenced execution callback info failed with " << toString(status);
+ return std::make_pair(NN_TRY(nn::convert(timingLaunched)), NN_TRY(nn::convert(timingFenced)));
+}
+
+} // namespace
+
+nn::GeneralResult<std::shared_ptr<const PreparedModel>> PreparedModel::create(
+ std::shared_ptr<aidl_hal::IPreparedModel> preparedModel) {
+ // The wrapped interface is dereferenced by execute()/executeFenced(), so a null one is refused.
+ if (preparedModel != nullptr) {
+ return std::make_shared<const PreparedModel>(PrivateConstructorTag{}, std::move(preparedModel));
+ }
+ return NN_ERROR()
+ << "aidl_hal::utils::PreparedModel::create must have non-null preparedModel";
+}
+
+PreparedModel::PreparedModel(PrivateConstructorTag /*tag*/,
+ std::shared_ptr<aidl_hal::IPreparedModel> preparedModel)
+ : kPreparedModel(std::move(preparedModel)) {}
+
+nn::ExecutionResult<std::pair<std::vector<nn::OutputShape>, nn::Timing>> PreparedModel::execute(
+ const nn::Request& request, nn::MeasureTiming measure,
+ const nn::OptionalTimePoint& deadline,
+ const nn::OptionalDuration& loopTimeoutDuration) const {
+ // Make the request IPC-ready. If its data had to be moved into shared memory, `sharedCopy`
+ // owns that copy for the remainder of the call and `ipcRequest` refers to it.
+ std::optional<nn::Request> sharedCopy;
+ const nn::Request& ipcRequest = NN_TRY(hal::utils::makeExecutionFailure(
+ hal::utils::flushDataFromPointerToShared(&request, &sharedCopy)));
+
+ // Canonical -> AIDL conversions; any failure is reported as an execution failure.
+ const auto aidlRequest = NN_TRY(hal::utils::makeExecutionFailure(convert(ipcRequest)));
+ const auto aidlMeasure = NN_TRY(hal::utils::makeExecutionFailure(convert(measure)));
+ const auto aidlDeadline = NN_TRY(hal::utils::makeExecutionFailure(convert(deadline)));
+ const auto aidlLoopTimeoutDuration =
+ NN_TRY(hal::utils::makeExecutionFailure(convert(loopTimeoutDuration)));
+
+ ExecutionResult syncResult;
+ const auto ret = kPreparedModel->executeSynchronously(
+ aidlRequest, aidlMeasure, aidlDeadline, aidlLoopTimeoutDuration, &syncResult);
+ HANDLE_ASTATUS(ret) << "executeSynchronously failed";
+
+ // Insufficient output size is an error that still carries the output shapes. This is best
+ // effort: the shapes become an empty list if they cannot be converted.
+ if (!syncResult.outputSufficientSize) {
+ auto reportedShapes =
+ nn::convert(syncResult.outputShapes).value_or(std::vector<nn::OutputShape>{});
+ return NN_ERROR(nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE, std::move(reportedShapes))
+ << "execution failed with " << nn::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
+ }
+
+ auto [shapes, timing] = NN_TRY(hal::utils::makeExecutionFailure(
+ convertExecutionResults(syncResult.outputShapes, syncResult.timing)));
+
+ // Copy outputs from the shared-memory copy (if any) back into the caller's pointer buffers.
+ NN_TRY(hal::utils::makeExecutionFailure(
+ hal::utils::unflushDataFromSharedToPointer(request, sharedCopy)));
+
+ return std::make_pair(std::move(shapes), timing);
+}
+
+// Launches a fenced execution on the AIDL prepared model.
+// Returns the sync fence that signals completion (an already-signaled fence when the driver
+// returns no fence fd, i.e. -1), plus a callback that later retrieves the execution's error
+// status and launched/fenced timings.
+// If the request had to be copied into shared memory for IPC, this call blocks until the fence
+// signals so the results can be copied back before returning.
+nn::GeneralResult<std::pair<nn::SyncFence, nn::ExecuteFencedInfoCallback>>
+PreparedModel::executeFenced(const nn::Request& request, const std::vector<nn::SyncFence>& waitFor,
+ nn::MeasureTiming measure, const nn::OptionalTimePoint& deadline,
+ const nn::OptionalDuration& loopTimeoutDuration,
+ const nn::OptionalDuration& timeoutDurationAfterFence) const {
+ // Ensure that request is ready for IPC.
+ std::optional<nn::Request> maybeRequestInShared;
+ const nn::Request& requestInShared =
+ NN_TRY(hal::utils::flushDataFromPointerToShared(&request, &maybeRequestInShared));
+
+ // Canonical -> AIDL conversions; the first failure is returned.
+ const auto aidlRequest = NN_TRY(convert(requestInShared));
+ const auto aidlWaitFor = NN_TRY(convert(waitFor));
+ const auto aidlMeasure = NN_TRY(convert(measure));
+ const auto aidlDeadline = NN_TRY(convert(deadline));
+ const auto aidlLoopTimeoutDuration = NN_TRY(convert(loopTimeoutDuration));
+ const auto aidlTimeoutDurationAfterFence = NN_TRY(convert(timeoutDurationAfterFence));
+
+ FencedExecutionResult result;
+ const auto ret = kPreparedModel->executeFenced(aidlRequest, aidlWaitFor, aidlMeasure,
+ aidlDeadline, aidlLoopTimeoutDuration,
+ aidlTimeoutDurationAfterFence, &result);
+ HANDLE_ASTATUS(ret) << "executeFenced failed";
+
+ // A returned fd of -1 means "no fence"; it is represented as an already-signaled fence.
+ auto resultSyncFence = nn::SyncFence::createAsSignaled();
+ if (result.syncFence.get() != -1) {
+ resultSyncFence = NN_TRY(nn::convert(result.syncFence));
+ }
+
+ // The info callback is mandatory; without it execution status/timings cannot be queried.
+ auto callback = result.callback;
+ if (callback == nullptr) {
+ return NN_ERROR(nn::ErrorStatus::GENERAL_FAILURE) << "callback is null";
+ }
+
+ // If executeFenced required the request memory to be moved into shared memory, block here until
+ // the fenced execution has completed and flush the memory back.
+ // NOTE(review): syncWait({}) is called with an empty timeout, presumably meaning "wait without
+ // a timeout" — confirm against nn::SyncFence::syncWait.
+ if (maybeRequestInShared.has_value()) {
+ const auto state = resultSyncFence.syncWait({});
+ if (state != nn::SyncFence::FenceState::SIGNALED) {
+ return NN_ERROR() << "syncWait failed with " << state;
+ }
+ NN_TRY(hal::utils::unflushDataFromSharedToPointer(request, maybeRequestInShared));
+ }
+
+ // Create callback which can be used to retrieve the execution error status and timings.
+ // `callback` is captured by value so the AIDL callback stays reachable from resultCallback.
+ nn::ExecuteFencedInfoCallback resultCallback =
+ [callback]() -> nn::GeneralResult<std::pair<nn::Timing, nn::Timing>> {
+ // These out-parameters are only read after getExecutionInfo reported success.
+ ErrorStatus errorStatus;
+ Timing timingLaunched;
+ Timing timingFenced;
+ const auto ret = callback->getExecutionInfo(&timingLaunched, &timingFenced, &errorStatus);
+ HANDLE_ASTATUS(ret) << "fenced execution callback getExecutionInfo failed";
+ return convertFencedExecutionResults(errorStatus, timingLaunched, timingFenced);
+ };
+
+ return std::make_pair(std::move(resultSyncFence), std::move(resultCallback));
+}
+
+nn::GeneralResult<nn::SharedBurst> PreparedModel::configureExecutionBurst() const {
+ // Delegates burst creation to hal::V1_0::utils::Burst, handing it shared ownership of this
+ // prepared model.
+ auto self = shared_from_this();
+ return hal::V1_0::utils::Burst::create(std::move(self));
+}
+
+std::any PreparedModel::getUnderlyingResource() const {
+ // Type-erases the AIDL interface as exactly std::shared_ptr<aidl_hal::IPreparedModel>; callers
+ // must std::any_cast to that type to recover it.
+ return std::any(std::shared_ptr<aidl_hal::IPreparedModel>(kPreparedModel));
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/src/ProtectCallback.cpp b/neuralnetworks/aidl/utils/src/ProtectCallback.cpp
new file mode 100644
index 0000000..124641c
--- /dev/null
+++ b/neuralnetworks/aidl/utils/src/ProtectCallback.cpp
@@ -0,0 +1,112 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ProtectCallback.h"
+
+#include <android-base/logging.h>
+#include <android-base/scopeguard.h>
+#include <android-base/thread_annotations.h>
+#include <android/binder_auto_utils.h>
+#include <android/binder_interface_utils.h>
+#include <nnapi/Result.h>
+#include <nnapi/hal/ProtectCallback.h>
+
+#include <algorithm>
+#include <functional>
+#include <memory>
+#include <mutex>
+#include <vector>
+
+#include "Utils.h"
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+void DeathMonitor::serviceDied() {
+ // Tell every registered callback that the remote object died. The sweep runs under mMutex,
+ // the same lock add()/remove() take, so registrations cannot change mid-sweep.
+ std::lock_guard guard(mMutex);
+ for (auto* const killable : mObjects) {
+ killable->notifyAsDeadObject();
+ }
+}
+
+void DeathMonitor::serviceDied(void* cookie) {
+ // Binder death callback registered via AIBinder_DeathRecipient_new. The cookie is the
+ // DeathMonitor pointer that DeathHandler::create passes to AIBinder_linkToDeath.
+ static_cast<DeathMonitor*>(cookie)->serviceDied();
+}
+
+void DeathMonitor::add(hal::utils::IProtectedCallback* killable) const {
+ // Registers `killable` for notification by serviceDied(); a null callback is fatal.
+ CHECK(killable != nullptr);
+ const std::lock_guard lock(mMutex);
+ mObjects.push_back(killable);
+}
+
+// Deregisters `killable` so serviceDied() no longer notifies it; a null callback is fatal.
+// Every registration of `killable` is removed, and removing an unregistered callback is a no-op.
+void DeathMonitor::remove(hal::utils::IProtectedCallback* killable) const {
+ CHECK(killable != nullptr);
+ std::lock_guard guard(mMutex);
+ // Erase-remove idiom: std::remove compacts the kept elements to the front and returns the
+ // new logical end, so the whole tail [removedIter, end) must be erased.
+ // Calling erase(removedIter) alone was undefined behaviour when `killable` was absent
+ // (removedIter == end()). It also left stale entries behind when `killable` appeared more
+ // than once.
+ const auto removedIter = std::remove(mObjects.begin(), mObjects.end(), killable);
+ mObjects.erase(removedIter, mObjects.end());
+}
+
+nn::GeneralResult<DeathHandler> DeathHandler::create(std::shared_ptr<ndk::ICInterface> object) {
+ // There is nothing to watch without an object.
+ if (object == nullptr) {
+ return NN_ERROR(nn::ErrorStatus::INVALID_ARGUMENT)
+ << "utils::DeathHandler::create must have non-null object";
+ }
+
+ auto monitor = std::make_shared<DeathMonitor>();
+ ndk::ScopedAIBinder_DeathRecipient recipient(
+ AIBinder_DeathRecipient_new(DeathMonitor::serviceDied));
+
+ // Local binders make AIBinder_linkToDeath a no-op that returns STATUS_INVALID_OPERATION.
+ // Local binders are only used in tests, so linking is attempted for remote objects only.
+ if (object->isRemote()) {
+ const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_linkToDeath(
+ object->asBinder().get(), recipient.get(), monitor.get()));
+ HANDLE_ASTATUS(ret) << "AIBinder_linkToDeath failed";
+ }
+
+ return DeathHandler(std::move(object), std::move(recipient), std::move(monitor));
+}
+
+DeathHandler::DeathHandler(std::shared_ptr<ndk::ICInterface> object,
+ ndk::ScopedAIBinder_DeathRecipient deathRecipient,
+ std::shared_ptr<DeathMonitor> deathMonitor)
+ : kObject(std::move(object)),
+ kDeathRecipient(std::move(deathRecipient)),
+ kDeathMonitor(std::move(deathMonitor)) {
+ CHECK(kObject != nullptr);
+ CHECK(kDeathRecipient.get() != nullptr);
+ CHECK(kDeathMonitor != nullptr);
+}
+
+// Unlinks the death recipient that create() linked to the watched object, if any.
+// Failures are logged rather than propagated because a destructor cannot report them.
+DeathHandler::~DeathHandler() {
+ // Nothing to unlink if any resource is missing (presumably a moved-from handler).
+ if (kObject == nullptr || kDeathRecipient.get() == nullptr || kDeathMonitor == nullptr) {
+ return;
+ }
+ // create() only calls AIBinder_linkToDeath for remote binders, so a local binder was never
+ // linked. Unlinking it would presumably fail the same way linking does
+ // (STATUS_INVALID_OPERATION) and log a spurious error on every destruction.
+ if (!kObject->isRemote()) {
+ return;
+ }
+ const auto ret = ndk::ScopedAStatus::fromStatus(AIBinder_unlinkToDeath(
+ kObject->asBinder().get(), kDeathRecipient.get(), kDeathMonitor.get()));
+ const auto maybeSuccess = handleTransportError(ret);
+ if (!maybeSuccess.ok()) {
+ LOG(ERROR) << maybeSuccess.error().message;
+ }
+}
+
+[[nodiscard]] ::android::base::ScopeGuard<DeathHandler::Cleanup> DeathHandler::protectCallback(
+ hal::utils::IProtectedCallback* killable) const {
+ // Registers `killable` with the death monitor and returns a guard that deregisters it.
+ CHECK(killable != nullptr);
+ kDeathMonitor->add(killable);
+ // The guard holds its own shared_ptr copy of the monitor, so deregistration stays valid
+ // even if the guard outlives this handler.
+ auto deregister = [deathMonitor = kDeathMonitor, killable] { deathMonitor->remove(killable); };
+ return ::android::base::make_scope_guard(std::move(deregister));
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/src/Service.cpp b/neuralnetworks/aidl/utils/src/Service.cpp
new file mode 100644
index 0000000..5ec6ded
--- /dev/null
+++ b/neuralnetworks/aidl/utils/src/Service.cpp
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "Service.h"
+
+#include <android/binder_auto_utils.h>
+#include <android/binder_manager.h>
+
+#include <nnapi/IDevice.h>
+#include <nnapi/Result.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/ResilientDevice.h>
+#include <string>
+
+#include "Device.h"
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+nn::GeneralResult<nn::SharedDevice> getDevice(const std::string& name) {
+ // Device factory handed to ResilientDevice. When `blocking` is true it looks the service up
+ // with AServiceManager_getService, otherwise with AServiceManager_checkService.
+ hal::utils::ResilientDevice::Factory makeDevice =
+ [name](bool blocking) -> nn::GeneralResult<nn::SharedDevice> {
+ const char* const lookupName =
+ blocking ? "AServiceManager_getService" : "AServiceManager_checkService";
+ ndk::SpAIBinder binder(blocking ? AServiceManager_getService(name.c_str())
+ : AServiceManager_checkService(name.c_str()));
+ auto service = IDevice::fromBinder(binder);
+ if (service == nullptr) {
+ return NN_ERROR() << lookupName << " returned nullptr";
+ }
+ return Device::create(name, std::move(service));
+ };
+
+ return hal::utils::ResilientDevice::create(std::move(makeDevice));
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/src/Utils.cpp b/neuralnetworks/aidl/utils/src/Utils.cpp
index 8d00e59..95516c8 100644
--- a/neuralnetworks/aidl/utils/src/Utils.cpp
+++ b/neuralnetworks/aidl/utils/src/Utils.cpp
@@ -16,13 +16,12 @@
#include "Utils.h"
+#include <android/binder_status.h>
#include <nnapi/Result.h>
namespace aidl::android::hardware::neuralnetworks::utils {
namespace {
-using ::android::nn::GeneralResult;
-
template <typename Type>
nn::GeneralResult<std::vector<Type>> cloneVec(const std::vector<Type>& arguments) {
std::vector<Type> clonedObjects;
@@ -34,13 +33,13 @@
}
template <typename Type>
-GeneralResult<std::vector<Type>> clone(const std::vector<Type>& arguments) {
+nn::GeneralResult<std::vector<Type>> clone(const std::vector<Type>& arguments) {
return cloneVec(arguments);
}
} // namespace
-GeneralResult<Memory> clone(const Memory& memory) {
+nn::GeneralResult<Memory> clone(const Memory& memory) {
common::NativeHandle nativeHandle;
nativeHandle.ints = memory.handle.ints;
nativeHandle.fds.reserve(memory.handle.fds.size());
@@ -58,7 +57,7 @@
};
}
-GeneralResult<RequestMemoryPool> clone(const RequestMemoryPool& requestPool) {
+nn::GeneralResult<RequestMemoryPool> clone(const RequestMemoryPool& requestPool) {
using Tag = RequestMemoryPool::Tag;
switch (requestPool.getTag()) {
case Tag::pool:
@@ -70,10 +69,10 @@
// compiler.
return (NN_ERROR() << "Unrecognized request pool tag: " << requestPool.getTag())
.
- operator GeneralResult<RequestMemoryPool>();
+ operator nn::GeneralResult<RequestMemoryPool>();
}
-GeneralResult<Request> clone(const Request& request) {
+nn::GeneralResult<Request> clone(const Request& request) {
return Request{
.inputs = request.inputs,
.outputs = request.outputs,
@@ -81,7 +80,7 @@
};
}
-GeneralResult<Model> clone(const Model& model) {
+nn::GeneralResult<Model> clone(const Model& model) {
return Model{
.main = model.main,
.referenced = model.referenced,
@@ -92,4 +91,20 @@
};
}
+// Maps a binder ScopedAStatus onto nn::GeneralResult<void>:
+// - STATUS_DEAD_OBJECT -> DEAD_OBJECT,
+// - OK -> success,
+// - any non-service-specific exception -> GENERAL_FAILURE,
+// - a service-specific error -> its code reinterpreted as nn::ErrorStatus.
+nn::GeneralResult<void> handleTransportError(const ndk::ScopedAStatus& ret) {
+ const bool isDeadObject = ret.getStatus() == STATUS_DEAD_OBJECT;
+ if (isDeadObject) {
+ return nn::error(nn::ErrorStatus::DEAD_OBJECT)
+ << "Binder transaction returned STATUS_DEAD_OBJECT: " << ret.getDescription();
+ }
+ if (ret.isOk()) {
+ return {};
+ }
+ const bool isServiceSpecific = ret.getExceptionCode() == EX_SERVICE_SPECIFIC;
+ if (!isServiceSpecific) {
+ return nn::error(nn::ErrorStatus::GENERAL_FAILURE)
+ << "Binder transaction returned exception: " << ret.getDescription();
+ }
+ // NOTE(review): the service-specific code is cast to nn::ErrorStatus without range checking;
+ // confirm services only ever report valid (non-NONE) ErrorStatus values.
+ const auto code = static_cast<nn::ErrorStatus>(ret.getServiceSpecificError());
+ return nn::error(code) << ret.getMessage();
+}
+
} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/test/BufferTest.cpp b/neuralnetworks/aidl/utils/test/BufferTest.cpp
new file mode 100644
index 0000000..9736160
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/BufferTest.cpp
@@ -0,0 +1,212 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MockBuffer.h"
+
+#include <aidl/android/hardware/neuralnetworks/ErrorStatus.h>
+#include <aidl/android/hardware/neuralnetworks/IBuffer.h>
+#include <android/binder_auto_utils.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <nnapi/IBuffer.h>
+#include <nnapi/SharedMemory.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/aidl/Buffer.h>
+
+#include <functional>
+#include <memory>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+namespace {
+
+using ::testing::_;
+using ::testing::InvokeWithoutArgs;
+
+// Shared fixtures:
+// - a 4-byte shared memory region to copy to/from,
+// - a null IBuffer,
+// - two memory domain tokens: 0 is rejected by Buffer::create, 1 is accepted.
+const auto kMemory = nn::createSharedMemory(4).value();
+const std::shared_ptr<IBuffer> kInvalidBuffer;
+constexpr auto kInvalidToken = nn::Request::MemoryDomainToken{0};
+constexpr auto kToken = nn::Request::MemoryDomainToken{1};
+
+// Mock action: successful binder call.
+constexpr auto makeStatusOk = [] { return ndk::ScopedAStatus::ok(); };
+
+// Mock action: driver-reported (service-specific) GENERAL_FAILURE.
+constexpr auto makeGeneralFailure = [] {
+ return ndk::ScopedAStatus::fromServiceSpecificError(
+ static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
+};
+// Mock action: transport error, expected to surface as GENERAL_FAILURE.
+constexpr auto makeGeneralTransportFailure = [] {
+ return ndk::ScopedAStatus::fromStatus(STATUS_NO_MEMORY);
+};
+// Mock action: remote process death, expected to surface as DEAD_OBJECT.
+constexpr auto makeDeadObjectFailure = [] {
+ return ndk::ScopedAStatus::fromStatus(STATUS_DEAD_OBJECT);
+};
+
+} // namespace
+
+TEST(BufferTest, invalidBuffer) {
+ // A null IBuffer is rejected when the Buffer is created.
+ const auto created = Buffer::create(kInvalidBuffer, kToken);
+
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(BufferTest, invalidToken) {
+ // Token 0 is rejected even when the IBuffer itself is valid.
+ const auto bufferMock = MockBuffer::create();
+ const auto created = Buffer::create(bufferMock, kInvalidToken);
+
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(BufferTest, create) {
+ // A successfully created Buffer reports the token it was created with.
+ const auto bufferMock = MockBuffer::create();
+ const auto buffer = Buffer::create(bufferMock, kToken).value();
+
+ EXPECT_EQ(buffer->getToken(), kToken);
+}
+
+TEST(BufferTest, copyTo) {
+ // A successful IBuffer::copyTo makes Buffer::copyTo succeed.
+ const auto bufferMock = MockBuffer::create();
+ const auto buffer = Buffer::create(bufferMock, kToken).value();
+ // A single WillOnce implies Times(1).
+ EXPECT_CALL(*bufferMock, copyTo(_)).WillOnce(InvokeWithoutArgs(makeStatusOk));
+
+ const auto copied = buffer->copyTo(kMemory);
+
+ EXPECT_TRUE(copied.has_value()) << copied.error().message;
+}
+
+TEST(BufferTest, copyToError) {
+ // A driver-reported GENERAL_FAILURE is propagated unchanged.
+ const auto bufferMock = MockBuffer::create();
+ const auto buffer = Buffer::create(bufferMock, kToken).value();
+ EXPECT_CALL(*bufferMock, copyTo(_)).WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ const auto copied = buffer->copyTo(kMemory);
+
+ ASSERT_FALSE(copied.has_value());
+ EXPECT_EQ(copied.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(BufferTest, copyToTransportFailure) {
+ // A binder transport error surfaces as GENERAL_FAILURE.
+ const auto bufferMock = MockBuffer::create();
+ const auto buffer = Buffer::create(bufferMock, kToken).value();
+ EXPECT_CALL(*bufferMock, copyTo(_)).WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ const auto copied = buffer->copyTo(kMemory);
+
+ ASSERT_FALSE(copied.has_value());
+ EXPECT_EQ(copied.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(BufferTest, copyToDeadObject) {
+ // A dead remote surfaces as DEAD_OBJECT.
+ const auto bufferMock = MockBuffer::create();
+ const auto buffer = Buffer::create(bufferMock, kToken).value();
+ EXPECT_CALL(*bufferMock, copyTo(_)).WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ const auto copied = buffer->copyTo(kMemory);
+
+ ASSERT_FALSE(copied.has_value());
+ EXPECT_EQ(copied.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(BufferTest, copyFrom) {
+ // A successful IBuffer::copyFrom makes Buffer::copyFrom succeed.
+ const auto mockBuffer = MockBuffer::create();
+ const auto buffer = Buffer::create(mockBuffer, kToken).value();
+ EXPECT_CALL(*mockBuffer, copyFrom(_, _)).Times(1).WillOnce(InvokeWithoutArgs(makeStatusOk));
+
+ // run test
+ const auto result = buffer->copyFrom(kMemory, {});
+
+ // verify result; like the copyTo test, report the error message if the call failed
+ EXPECT_TRUE(result.has_value()) << result.error().message;
+}
+
+TEST(BufferTest, copyFromError) {
+ // A driver-reported GENERAL_FAILURE is propagated unchanged.
+ const auto bufferMock = MockBuffer::create();
+ const auto buffer = Buffer::create(bufferMock, kToken).value();
+ EXPECT_CALL(*bufferMock, copyFrom(_, _)).WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ const auto copied = buffer->copyFrom(kMemory, {});
+
+ ASSERT_FALSE(copied.has_value());
+ EXPECT_EQ(copied.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(BufferTest, copyFromTransportFailure) {
+ // A binder transport error surfaces as GENERAL_FAILURE.
+ const auto bufferMock = MockBuffer::create();
+ const auto buffer = Buffer::create(bufferMock, kToken).value();
+ EXPECT_CALL(*bufferMock, copyFrom(_, _))
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ const auto copied = buffer->copyFrom(kMemory, {});
+
+ ASSERT_FALSE(copied.has_value());
+ EXPECT_EQ(copied.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(BufferTest, copyFromDeadObject) {
+ // A dead remote surfaces as DEAD_OBJECT.
+ const auto bufferMock = MockBuffer::create();
+ const auto buffer = Buffer::create(bufferMock, kToken).value();
+ EXPECT_CALL(*bufferMock, copyFrom(_, _)).WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ const auto copied = buffer->copyFrom(kMemory, {});
+
+ ASSERT_FALSE(copied.has_value());
+ EXPECT_EQ(copied.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/test/DeviceTest.cpp b/neuralnetworks/aidl/utils/test/DeviceTest.cpp
new file mode 100644
index 0000000..e53b0a8
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/DeviceTest.cpp
@@ -0,0 +1,861 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MockBuffer.h"
+#include "MockDevice.h"
+#include "MockPreparedModel.h"
+
+#include <aidl/android/hardware/neuralnetworks/BnDevice.h>
+#include <android/binder_auto_utils.h>
+#include <android/binder_status.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <nnapi/IDevice.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/aidl/Device.h>
+
+#include <functional>
+#include <memory>
+#include <string>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+namespace {
+
+namespace nn = ::android::nn;
+using ::testing::_;
+using ::testing::DoAll;
+using ::testing::Invoke;
+using ::testing::InvokeWithoutArgs;
+using ::testing::SetArgPointee;
+
+// Minimal single-operation model: RELU from one TENSOR_FLOAT32 input of dimensions {1}
+// (operand 0) to one TENSOR_FLOAT32 output of dimensions {1} (operand 1).
+const nn::Model kSimpleModel = {
+ .main = {.operands = {{.type = nn::OperandType::TENSOR_FLOAT32,
+ .dimensions = {1},
+ .lifetime = nn::Operand::LifeTime::SUBGRAPH_INPUT},
+ {.type = nn::OperandType::TENSOR_FLOAT32,
+ .dimensions = {1},
+ .lifetime = nn::Operand::LifeTime::SUBGRAPH_OUTPUT}},
+ .operations = {{.type = nn::OperationType::RELU, .inputs = {0}, .outputs = {1}}},
+ .inputIndexes = {0},
+ .outputIndexes = {1}}};
+
+// Device identities: a valid name, an empty (rejected) name, and a null device.
+const std::string kName = "Google-MockV1";
+const std::string kInvalidName = "";
+const std::shared_ptr<BnDevice> kInvalidDevice;
+// Performance entry with execTime and powerUsage both set to float max.
+// NOTE(review): std::numeric_limits is used here, but <limits> is not included directly by
+// this file; it is presumably pulled in transitively — consider including it explicitly.
+constexpr PerformanceInfo kNoPerformanceInfo = {.execTime = std::numeric_limits<float>::max(),
+ .powerUsage = std::numeric_limits<float>::max()};
+// Requests the maximum allowed number of model and data cache files.
+constexpr NumberOfCacheFiles kNumberOfCacheFiles = {.numModelCache = nn::kMaxNumberOfCacheFiles,
+ .numDataCache = nn::kMaxNumberOfCacheFiles};
+
+// Mock action: successful binder call.
+constexpr auto makeStatusOk = [] { return ndk::ScopedAStatus::ok(); };
+
+// Creates a MockDevice whose metadata queries succeed by default. The defaults are:
+// - version string kName,
+// - DeviceType::OTHER,
+// - no extensions,
+// - kNumberOfCacheFiles,
+// - capabilities built from kNoPerformanceInfo.
+// Individual tests override a single query with their own EXPECT_CALL to inject failures.
+std::shared_ptr<MockDevice> createMockDevice() {
+ const auto mockDevice = MockDevice::create();
+
+ // Setup default actions for each relevant call.
+ ON_CALL(*mockDevice, getVersionString(_))
+ .WillByDefault(DoAll(SetArgPointee<0>(kName), InvokeWithoutArgs(makeStatusOk)));
+ ON_CALL(*mockDevice, getType(_))
+ .WillByDefault(
+ DoAll(SetArgPointee<0>(DeviceType::OTHER), InvokeWithoutArgs(makeStatusOk)));
+ ON_CALL(*mockDevice, getSupportedExtensions(_))
+ .WillByDefault(DoAll(SetArgPointee<0>(std::vector<Extension>{}),
+ InvokeWithoutArgs(makeStatusOk)));
+ ON_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
+ .WillByDefault(
+ DoAll(SetArgPointee<0>(kNumberOfCacheFiles), InvokeWithoutArgs(makeStatusOk)));
+ ON_CALL(*mockDevice, getCapabilities(_))
+ .WillByDefault(
+ DoAll(SetArgPointee<0>(Capabilities{
+ .relaxedFloat32toFloat16PerformanceScalar = kNoPerformanceInfo,
+ .relaxedFloat32toFloat16PerformanceTensor = kNoPerformanceInfo,
+ .ifPerformance = kNoPerformanceInfo,
+ .whilePerformance = kNoPerformanceInfo,
+ }),
+ InvokeWithoutArgs(makeStatusOk)));
+
+ // These EXPECT_CALL(...).Times(testing::AnyNumber()) calls are to suppress warnings on the
+ // uninteresting methods calls.
+ // A test's later EXPECT_CALL on the same method takes precedence over these, because gMock
+ // matches the most recently declared expectation first.
+ EXPECT_CALL(*mockDevice, getVersionString(_)).Times(testing::AnyNumber());
+ EXPECT_CALL(*mockDevice, getType(_)).Times(testing::AnyNumber());
+ EXPECT_CALL(*mockDevice, getSupportedExtensions(_)).Times(testing::AnyNumber());
+ EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_)).Times(testing::AnyNumber());
+ EXPECT_CALL(*mockDevice, getCapabilities(_)).Times(testing::AnyNumber());
+
+ return mockDevice;
+}
+
+// Shared body of the prepareModel/prepareModelFromCache fakes.
+// It first reports `returnStatus` and `preparedModel` through the callback. It then returns OK
+// when `launchStatus` is NONE, and otherwise a service-specific error carrying `launchStatus`.
+constexpr auto makePreparedModelReturnImpl =
+ [](ErrorStatus launchStatus, ErrorStatus returnStatus,
+ const std::shared_ptr<MockPreparedModel>& preparedModel,
+ const std::shared_ptr<IPreparedModelCallback>& cb) {
+ cb->notify(returnStatus, preparedModel);
+ return launchStatus == ErrorStatus::NONE
+ ? ndk::ScopedAStatus::ok()
+ : ndk::ScopedAStatus::fromServiceSpecificError(static_cast<int32_t>(launchStatus));
+ };
+
+// Builds a mock action matching the prepareModel argument list. All request arguments are
+// ignored and the call is completed through makePreparedModelReturnImpl.
+auto makePreparedModelReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
+ const std::shared_ptr<MockPreparedModel>& preparedModel) {
+ auto action = [launchStatus, returnStatus, preparedModel](
+ const Model& /*model*/, ExecutionPreference /*preference*/,
+ Priority /*priority*/, const int64_t& /*deadline*/,
+ const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
+ const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,
+ const std::vector<uint8_t>& /*token*/,
+ const std::shared_ptr<IPreparedModelCallback>& cb) -> ndk::ScopedAStatus {
+ return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
+ };
+ return action;
+}
+
+// Builds a mock action matching the prepareModelFromCache argument list. All request arguments
+// are ignored and the call is completed through makePreparedModelReturnImpl.
+auto makePreparedModelFromCacheReturn(ErrorStatus launchStatus, ErrorStatus returnStatus,
+ const std::shared_ptr<MockPreparedModel>& preparedModel) {
+ auto action = [launchStatus, returnStatus, preparedModel](
+ const int64_t& /*deadline*/,
+ const std::vector<ndk::ScopedFileDescriptor>& /*modelCache*/,
+ const std::vector<ndk::ScopedFileDescriptor>& /*dataCache*/,
+ const std::vector<uint8_t>& /*token*/,
+ const std::shared_ptr<IPreparedModelCallback>& cb) -> ndk::ScopedAStatus {
+ return makePreparedModelReturnImpl(launchStatus, returnStatus, preparedModel, cb);
+ };
+ return action;
+}
+
+constexpr auto makeGeneralFailure = [] {
+ return ndk::ScopedAStatus::fromServiceSpecificError(
+ static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE));
+};
+constexpr auto makeGeneralTransportFailure = [] {
+ return ndk::ScopedAStatus::fromStatus(STATUS_NO_MEMORY);
+};
+constexpr auto makeDeadObjectFailure = [] {
+ return ndk::ScopedAStatus::fromStatus(STATUS_DEAD_OBJECT);
+};
+
+} // namespace
+
+TEST(DeviceTest, invalidName) {
+ // An empty device name is rejected with INVALID_ARGUMENT.
+ const auto mockDevice = MockDevice::create();
+ const auto created = Device::create(kInvalidName, mockDevice);
+
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::INVALID_ARGUMENT);
+}
+
+TEST(DeviceTest, invalidDevice) {
+ // A null device is rejected with INVALID_ARGUMENT.
+ const auto created = Device::create(kName, kInvalidDevice);
+
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::INVALID_ARGUMENT);
+}
+
+TEST(DeviceTest, getVersionStringError) {
+ // A driver-reported failure from getVersionString makes Device::create fail.
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getVersionString(_)).WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ const auto created = Device::create(kName, mockDevice);
+
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getVersionStringTransportFailure) {
+ // A transport error from getVersionString surfaces as GENERAL_FAILURE.
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getVersionString(_))
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ const auto created = Device::create(kName, mockDevice);
+
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getVersionStringDeadObject) {
+ // A dead remote during getVersionString surfaces as DEAD_OBJECT.
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getVersionString(_))
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ const auto created = Device::create(kName, mockDevice);
+
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, getTypeError) {
+ // A driver-reported failure from getType makes Device::create fail.
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getType(_)).WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ const auto created = Device::create(kName, mockDevice);
+
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getTypeTransportFailure) {
+ // A transport error from getType surfaces as GENERAL_FAILURE.
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getType(_)).WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ const auto created = Device::create(kName, mockDevice);
+
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getTypeDeadObject) {
+ // A dead remote during getType surfaces as DEAD_OBJECT.
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getType(_)).WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ const auto created = Device::create(kName, mockDevice);
+
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, getSupportedExtensionsError) {
+ // setup call: the driver answers getSupportedExtensions with a general failure
+ const auto driver = createMockDevice();
+ EXPECT_CALL(*driver, getSupportedExtensions(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test: the query is issued while Device::create builds the device
+ const auto created = Device::create(kName, driver);
+
+ // verify result: creation fails, reported as GENERAL_FAILURE
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getSupportedExtensionsTransportFailure) {
+ // setup call: the driver answers getSupportedExtensions with a transport failure
+ const auto driver = createMockDevice();
+ EXPECT_CALL(*driver, getSupportedExtensions(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test: the query is issued while Device::create builds the device
+ const auto created = Device::create(kName, driver);
+
+ // verify result: creation fails, reported as GENERAL_FAILURE
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getSupportedExtensionsDeadObject) {
+ // setup call: the driver answers getSupportedExtensions with a dead-object failure
+ const auto driver = createMockDevice();
+ EXPECT_CALL(*driver, getSupportedExtensions(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test: the query is issued while Device::create builds the device
+ const auto created = Device::create(kName, driver);
+
+ // verify result: creation fails, reported as DEAD_OBJECT
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, getNumberOfCacheFilesNeededError) {
+ // setup call: the driver answers getNumberOfCacheFilesNeeded with a general failure
+ const auto driver = createMockDevice();
+ EXPECT_CALL(*driver, getNumberOfCacheFilesNeeded(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test: the query is issued while Device::create builds the device
+ const auto created = Device::create(kName, driver);
+
+ // verify result: creation fails, reported as GENERAL_FAILURE
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, dataCacheFilesExceedsSpecifiedMax) {
+ // setup call: only the data cache count is above nn::kMaxNumberOfCacheFiles
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
+ .Times(1)
+ .WillOnce(DoAll(SetArgPointee<0>(NumberOfCacheFiles{
+ .numModelCache = nn::kMaxNumberOfCacheFiles,
+ .numDataCache = nn::kMaxNumberOfCacheFiles + 1}),
+ InvokeWithoutArgs(makeStatusOk)));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result: an out-of-range data cache count is rejected
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, modelCacheFilesExceedsSpecifiedMax) {
+ // setup call: only the model cache count is above nn::kMaxNumberOfCacheFiles
+ const auto mockDevice = createMockDevice();
+ EXPECT_CALL(*mockDevice, getNumberOfCacheFilesNeeded(_))
+ .Times(1)
+ .WillOnce(DoAll(SetArgPointee<0>(NumberOfCacheFiles{
+ .numModelCache = nn::kMaxNumberOfCacheFiles + 1,
+ .numDataCache = nn::kMaxNumberOfCacheFiles}),
+ InvokeWithoutArgs(makeStatusOk)));
+
+ // run test
+ const auto result = Device::create(kName, mockDevice);
+
+ // verify result: an out-of-range model cache count is rejected
+ ASSERT_FALSE(result.has_value());
+ EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getNumberOfCacheFilesNeededTransportFailure) {
+ // setup call: the driver answers getNumberOfCacheFilesNeeded with a transport failure
+ const auto driver = createMockDevice();
+ EXPECT_CALL(*driver, getNumberOfCacheFilesNeeded(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test: the query is issued while Device::create builds the device
+ const auto created = Device::create(kName, driver);
+
+ // verify result: creation fails, reported as GENERAL_FAILURE
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getNumberOfCacheFilesNeededDeadObject) {
+ // setup call: the driver answers getNumberOfCacheFilesNeeded with a dead-object failure
+ const auto driver = createMockDevice();
+ EXPECT_CALL(*driver, getNumberOfCacheFilesNeeded(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test: the query is issued while Device::create builds the device
+ const auto created = Device::create(kName, driver);
+
+ // verify result: creation fails, reported as DEAD_OBJECT
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, getCapabilitiesError) {
+ // setup call: the driver answers getCapabilities with a general failure
+ const auto driver = createMockDevice();
+ EXPECT_CALL(*driver, getCapabilities(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test: the query is issued while Device::create builds the device
+ const auto created = Device::create(kName, driver);
+
+ // verify result: creation fails, reported as GENERAL_FAILURE
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getCapabilitiesTransportFailure) {
+ // setup call: the driver answers getCapabilities with a transport failure
+ const auto driver = createMockDevice();
+ EXPECT_CALL(*driver, getCapabilities(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test: the query is issued while Device::create builds the device
+ const auto created = Device::create(kName, driver);
+
+ // verify result: creation fails, reported as GENERAL_FAILURE
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getCapabilitiesDeadObject) {
+ // setup call: the driver answers getCapabilities with a dead-object failure
+ const auto driver = createMockDevice();
+ EXPECT_CALL(*driver, getCapabilities(_))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test: the query is issued while Device::create builds the device
+ const auto created = Device::create(kName, driver);
+
+ // verify result: creation fails, reported as DEAD_OBJECT
+ ASSERT_FALSE(created.has_value());
+ EXPECT_EQ(created.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, getName) {
+ // setup call: wrap a default-configured driver under kName
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+
+ // run test
+ const auto& reportedName = device->getName();
+
+ // verify result: the name passed to Device::create is reported back
+ EXPECT_EQ(reportedName, kName);
+}
+
+TEST(DeviceTest, getFeatureLevel) {
+ // setup call: wrap a default-configured driver
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+
+ // run test
+ const auto level = device->getFeatureLevel();
+
+ // verify result: this device reports the Android S feature level
+ EXPECT_EQ(level, nn::Version::ANDROID_S);
+}
+
+TEST(DeviceTest, getCachedData) {
+ // setup call: each static query may reach the driver exactly once
+ const auto driver = createMockDevice();
+ EXPECT_CALL(*driver, getCapabilities(_)).Times(1);
+ EXPECT_CALL(*driver, getNumberOfCacheFilesNeeded(_)).Times(1);
+ EXPECT_CALL(*driver, getSupportedExtensions(_)).Times(1);
+ EXPECT_CALL(*driver, getType(_)).Times(1);
+ EXPECT_CALL(*driver, getVersionString(_)).Times(1);
+
+ const auto created = Device::create(kName, driver);
+ ASSERT_TRUE(created.has_value())
+ << "Failed with " << created.error().code << ": " << created.error().message;
+ const auto& device = created.value();
+
+ // run test and verify results: repeated reads must not trigger new driver calls
+ EXPECT_EQ(device->getCapabilities(), device->getCapabilities());
+ EXPECT_EQ(device->getNumberOfCacheFilesNeeded(), device->getNumberOfCacheFilesNeeded());
+ EXPECT_EQ(device->getSupportedExtensions(), device->getSupportedExtensions());
+ EXPECT_EQ(device->getType(), device->getType());
+ EXPECT_EQ(device->getVersionString(), device->getVersionString());
+}
+
+TEST(DeviceTest, getSupportedOperations) {
+ // setup call: the driver claims support for every operation in the model
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ const auto operationCount = kSimpleModel.main.operations.size();
+ EXPECT_CALL(*driver, getSupportedOperations(_, _))
+ .Times(1)
+ .WillOnce(DoAll(SetArgPointee<1>(std::vector<bool>(operationCount, true)),
+ InvokeWithoutArgs(makeStatusOk)));
+
+ // run test
+ const auto supported = device->getSupportedOperations(kSimpleModel);
+
+ // verify result: one `true` entry per operation
+ ASSERT_TRUE(supported.has_value())
+ << "Failed with " << supported.error().code << ": " << supported.error().message;
+ const auto& flags = supported.value();
+ EXPECT_EQ(flags.size(), operationCount);
+ EXPECT_THAT(flags, Each(testing::IsTrue()));
+}
+
+TEST(DeviceTest, getSupportedOperationsError) {
+ // setup call: wrap a working driver, then make the next query fail
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, getSupportedOperations(_, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test
+ const auto supported = device->getSupportedOperations(kSimpleModel);
+
+ // verify result: the failure is reported as GENERAL_FAILURE
+ ASSERT_FALSE(supported.has_value());
+ EXPECT_EQ(supported.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getSupportedOperationsTransportFailure) {
+ // setup call: wrap a working driver, then make the next query fail in transport
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, getSupportedOperations(_, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto supported = device->getSupportedOperations(kSimpleModel);
+
+ // verify result: transport failures surface as GENERAL_FAILURE
+ ASSERT_FALSE(supported.has_value());
+ EXPECT_EQ(supported.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, getSupportedOperationsDeadObject) {
+ // setup call: wrap a working driver, then report a dead object on the next query
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, getSupportedOperations(_, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto supported = device->getSupportedOperations(kSimpleModel);
+
+ // verify result: the dead object is reported as DEAD_OBJECT
+ ASSERT_FALSE(supported.has_value());
+ EXPECT_EQ(supported.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, prepareModel) {
+ // setup call: both statuses are NONE and a prepared model is delivered
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ const auto halPreparedModel = MockPreparedModel::create();
+ EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelReturn(ErrorStatus::NONE, ErrorStatus::NONE,
+ halPreparedModel)));
+
+ // run test
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ // verify result: a non-null prepared model is returned
+ ASSERT_TRUE(prepared.has_value())
+ << "Failed with " << prepared.error().code << ": " << prepared.error().message;
+ EXPECT_NE(prepared.value(), nullptr);
+}
+
+TEST(DeviceTest, prepareModelLaunchError) {
+ // setup call: launch error - both statuses are GENERAL_FAILURE, no model delivered
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelReturn(ErrorStatus::GENERAL_FAILURE,
+ ErrorStatus::GENERAL_FAILURE, nullptr)));
+
+ // run test
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ // verify result
+ ASSERT_FALSE(prepared.has_value());
+ EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelReturnError) {
+ // setup call: return error - first status NONE, second status GENERAL_FAILURE
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelReturn(ErrorStatus::NONE,
+ ErrorStatus::GENERAL_FAILURE, nullptr)));
+
+ // run test
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ // verify result
+ ASSERT_FALSE(prepared.has_value());
+ EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelNullptrError) {
+ // setup call: both statuses are NONE, yet no prepared model is delivered
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(
+ Invoke(makePreparedModelReturn(ErrorStatus::NONE, ErrorStatus::NONE, nullptr)));
+
+ // run test
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ // verify result: a null model is treated as GENERAL_FAILURE
+ ASSERT_FALSE(prepared.has_value());
+ EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelTransportFailure) {
+ // setup call: the prepareModel call fails at the transport level
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ // verify result: transport failures surface as GENERAL_FAILURE
+ ASSERT_FALSE(prepared.has_value());
+ EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelDeadObject) {
+ // setup call: the prepareModel call reports a dead object
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ // verify result: the dead object is reported as DEAD_OBJECT
+ ASSERT_FALSE(prepared.has_value());
+ EXPECT_EQ(prepared.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, prepareModelAsyncCrash) {
+ // setup call: the launch returns OK but the service "dies" without calling back
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ const auto dieThenAcknowledge = [&device]() {
+ DeathMonitor::serviceDied(device->getDeathMonitor());
+ return ndk::ScopedAStatus::ok();
+ };
+ EXPECT_CALL(*driver, prepareModel(_, _, _, _, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(dieThenAcknowledge));
+
+ // run test
+ const auto prepared = device->prepareModel(kSimpleModel, nn::ExecutionPreference::DEFAULT,
+ nn::Priority::DEFAULT, {}, {}, {}, {});
+
+ // verify result: the call fails with DEAD_OBJECT
+ ASSERT_FALSE(prepared.has_value());
+ EXPECT_EQ(prepared.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, prepareModelFromCache) {
+ // setup call: both statuses are NONE and a prepared model is delivered
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ const auto halPreparedModel = MockPreparedModel::create();
+ EXPECT_CALL(*driver, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelFromCacheReturn(ErrorStatus::NONE, ErrorStatus::NONE,
+ halPreparedModel)));
+
+ // run test
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ // verify result: a non-null prepared model is returned
+ ASSERT_TRUE(prepared.has_value())
+ << "Failed with " << prepared.error().code << ": " << prepared.error().message;
+ EXPECT_NE(prepared.value(), nullptr);
+}
+
+TEST(DeviceTest, prepareModelFromCacheLaunchError) {
+ // setup call: launch error - both statuses are GENERAL_FAILURE, no model delivered
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelFromCacheReturn(
+ ErrorStatus::GENERAL_FAILURE, ErrorStatus::GENERAL_FAILURE, nullptr)));
+
+ // run test
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ // verify result
+ ASSERT_FALSE(prepared.has_value());
+ EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelFromCacheReturnError) {
+ // setup call: return error - first status NONE, second status GENERAL_FAILURE
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelFromCacheReturn(
+ ErrorStatus::NONE, ErrorStatus::GENERAL_FAILURE, nullptr)));
+
+ // run test
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ // verify result
+ ASSERT_FALSE(prepared.has_value());
+ EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelFromCacheNullptrError) {
+ // setup call: both statuses are NONE, yet no prepared model is delivered
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(Invoke(makePreparedModelFromCacheReturn(ErrorStatus::NONE, ErrorStatus::NONE,
+ nullptr)));
+
+ // run test
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ // verify result: a null model is treated as GENERAL_FAILURE
+ ASSERT_FALSE(prepared.has_value());
+ EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelFromCacheTransportFailure) {
+ // setup call: the prepareModelFromCache call fails at the transport level
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ // verify result: transport failures surface as GENERAL_FAILURE
+ ASSERT_FALSE(prepared.has_value());
+ EXPECT_EQ(prepared.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, prepareModelFromCacheDeadObject) {
+ // setup call: the prepareModelFromCache call reports a dead object
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ // verify result: the dead object is reported as DEAD_OBJECT
+ ASSERT_FALSE(prepared.has_value());
+ EXPECT_EQ(prepared.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, prepareModelFromCacheAsyncCrash) {
+ // setup call: the launch returns OK but the service "dies" without calling back
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ const auto dieThenAcknowledge = [&device]() {
+ DeathMonitor::serviceDied(device->getDeathMonitor());
+ return ndk::ScopedAStatus::ok();
+ };
+ EXPECT_CALL(*driver, prepareModelFromCache(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(dieThenAcknowledge));
+
+ // run test
+ const auto prepared = device->prepareModelFromCache({}, {}, {}, {});
+
+ // verify result: the call fails with DEAD_OBJECT
+ ASSERT_FALSE(prepared.has_value());
+ EXPECT_EQ(prepared.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(DeviceTest, allocate) {
+ // setup call: the driver hands back a mock buffer together with token 1
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ const auto halBuffer = DeviceBuffer{.buffer = MockBuffer::create(), .token = 1};
+ EXPECT_CALL(*driver, allocate(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(DoAll(SetArgPointee<4>(halBuffer), InvokeWithoutArgs(makeStatusOk)));
+
+ // run test
+ const auto allocated = device->allocate({}, {}, {}, {});
+
+ // verify result: a non-null buffer is returned
+ ASSERT_TRUE(allocated.has_value())
+ << "Failed with " << allocated.error().code << ": " << allocated.error().message;
+ EXPECT_NE(allocated.value(), nullptr);
+}
+
+TEST(DeviceTest, allocateError) {
+ // setup call: the allocate call reports a general failure
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, allocate(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+ // run test
+ const auto allocated = device->allocate({}, {}, {}, {});
+
+ // verify result: the failure is reported as GENERAL_FAILURE
+ ASSERT_FALSE(allocated.has_value());
+ EXPECT_EQ(allocated.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, allocateTransportFailure) {
+ // setup call: the allocate call fails at the transport level
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, allocate(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+ // run test
+ const auto allocated = device->allocate({}, {}, {}, {});
+
+ // verify result: transport failures surface as GENERAL_FAILURE
+ ASSERT_FALSE(allocated.has_value());
+ EXPECT_EQ(allocated.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(DeviceTest, allocateDeadObject) {
+ // setup call: the allocate call reports a dead object
+ const auto driver = createMockDevice();
+ const auto device = Device::create(kName, driver).value();
+ EXPECT_CALL(*driver, allocate(_, _, _, _, _))
+ .Times(1)
+ .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+ // run test
+ const auto allocated = device->allocate({}, {}, {}, {});
+
+ // verify result: the dead object is reported as DEAD_OBJECT
+ ASSERT_FALSE(allocated.has_value());
+ EXPECT_EQ(allocated.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/aidl/utils/test/MockBuffer.h b/neuralnetworks/aidl/utils/test/MockBuffer.h
new file mode 100644
index 0000000..5746176
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/MockBuffer.h
@@ -0,0 +1,43 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_BUFFER
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_BUFFER
+
+#include <aidl/android/hardware/neuralnetworks/BnBuffer.h>
+#include <android/binder_interface_utils.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <hidl/Status.h>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// gmock double for the AIDL IBuffer interface; obtain instances through create().
+class MockBuffer final : public BnBuffer {
+ public:
+ static std::shared_ptr<MockBuffer> create() {
+ return ndk::SharedRefBase::make<MockBuffer>();
+ }
+
+ // Behavior is configured per test through gmock (EXPECT_CALL / ON_CALL).
+ MOCK_METHOD(ndk::ScopedAStatus, copyFrom,
+ (const Memory& src, const std::vector<int32_t>& dimensions), (override));
+ MOCK_METHOD(ndk::ScopedAStatus, copyTo, (const Memory& dst), (override));
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_BUFFER
diff --git a/neuralnetworks/aidl/utils/test/MockDevice.h b/neuralnetworks/aidl/utils/test/MockDevice.h
new file mode 100644
index 0000000..9b35bf8
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/MockDevice.h
@@ -0,0 +1,67 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_DEVICE
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_DEVICE
+
+#include <aidl/android/hardware/neuralnetworks/BnDevice.h>
+#include <android/binder_auto_utils.h>
+#include <android/binder_interface_utils.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// gmock double for the AIDL IDevice interface; obtain instances through create().
+class MockDevice final : public BnDevice {
+ public:
+ static std::shared_ptr<MockDevice> create() {
+ return ndk::SharedRefBase::make<MockDevice>();
+ }
+ // Device queries.
+ MOCK_METHOD(ndk::ScopedAStatus, getCapabilities, (Capabilities* capabilities), (override));
+ MOCK_METHOD(ndk::ScopedAStatus, getNumberOfCacheFilesNeeded,
+ (NumberOfCacheFiles* numberOfCacheFiles), (override));
+ MOCK_METHOD(ndk::ScopedAStatus, getSupportedExtensions, (std::vector<Extension>* extensions),
+ (override));
+ MOCK_METHOD(ndk::ScopedAStatus, getSupportedOperations,
+ (const Model& model, std::vector<bool>* supportedOperations), (override));
+ MOCK_METHOD(ndk::ScopedAStatus, getType, (DeviceType* deviceType), (override));
+ MOCK_METHOD(ndk::ScopedAStatus, getVersionString, (std::string* version), (override));
+ // Model preparation and buffer allocation.
+ MOCK_METHOD(ndk::ScopedAStatus, prepareModel,
+ (const Model& model, ExecutionPreference preference, Priority priority,
+ int64_t deadline, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
+ const std::vector<ndk::ScopedFileDescriptor>& dataCache,
+ const std::vector<uint8_t>& token,
+ const std::shared_ptr<IPreparedModelCallback>& callback),
+ (override));
+ MOCK_METHOD(ndk::ScopedAStatus, prepareModelFromCache,
+ (int64_t deadline, const std::vector<ndk::ScopedFileDescriptor>& modelCache,
+ const std::vector<ndk::ScopedFileDescriptor>& dataCache,
+ const std::vector<uint8_t>& token,
+ const std::shared_ptr<IPreparedModelCallback>& callback),
+ (override));
+ MOCK_METHOD(ndk::ScopedAStatus, allocate,
+ (const BufferDesc& desc, const std::vector<IPreparedModelParcel>& preparedModels,
+ const std::vector<BufferRole>& inputRoles,
+ const std::vector<BufferRole>& outputRoles, DeviceBuffer* deviceBuffer),
+ (override));
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_DEVICE
diff --git a/neuralnetworks/aidl/utils/test/MockFencedExecutionCallback.h b/neuralnetworks/aidl/utils/test/MockFencedExecutionCallback.h
new file mode 100644
index 0000000..463e1c9
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/MockFencedExecutionCallback.h
@@ -0,0 +1,45 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_FENCED_EXECUTION_CALLBACK
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_FENCED_EXECUTION_CALLBACK
+
+#include <aidl/android/hardware/neuralnetworks/BnFencedExecutionCallback.h>
+#include <android/binder_auto_utils.h>
+#include <android/binder_interface_utils.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <hidl/Status.h>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// gmock double for the AIDL IFencedExecutionCallback interface.
+class MockFencedExecutionCallback final : public BnFencedExecutionCallback {
+ public:
+ static std::shared_ptr<MockFencedExecutionCallback> create();
+ // AIDL IFencedExecutionCallback methods.
+ MOCK_METHOD(ndk::ScopedAStatus, getExecutionInfo,
+ (Timing* timingLaunched, Timing* timingFenced, ErrorStatus* errorStatus),
+ (override));
+};
+
+inline std::shared_ptr<MockFencedExecutionCallback> MockFencedExecutionCallback::create() {
+ return ndk::SharedRefBase::make<MockFencedExecutionCallback>();
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_FENCED_EXECUTION_CALLBACK
diff --git a/neuralnetworks/aidl/utils/test/MockPreparedModel.h b/neuralnetworks/aidl/utils/test/MockPreparedModel.h
new file mode 100644
index 0000000..545b491
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/MockPreparedModel.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_PREPARED_MODEL
+#define ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_PREPARED_MODEL
+
+#include <aidl/android/hardware/neuralnetworks/BnPreparedModel.h>
+#include <android/binder_interface_utils.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <hidl/HidlSupport.h>
+#include <hidl/Status.h>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+
+// gmock double for the AIDL IPreparedModel interface; obtain instances through create().
+class MockPreparedModel final : public BnPreparedModel {
+ public:
+ static std::shared_ptr<MockPreparedModel> create() {
+ return ndk::SharedRefBase::make<MockPreparedModel>();
+ }
+ // Fenced execution.
+ MOCK_METHOD(ndk::ScopedAStatus, executeFenced,
+ (const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
+ bool measureTiming, int64_t deadline, int64_t loopTimeoutDuration,
+ int64_t duration, FencedExecutionResult* fencedExecutionResult),
+ (override));
+ // Synchronous execution.
+ MOCK_METHOD(ndk::ScopedAStatus, executeSynchronously,
+ (const Request& request, bool measureTiming, int64_t deadline,
+ int64_t loopTimeoutDuration, ExecutionResult* executionResult),
+ (override));
+};
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
+
+#endif // ANDROID_HARDWARE_INTERFACES_NEURALNETWORKS_AIDL_UTILS_TEST_MOCK_PREPARED_MODEL
diff --git a/neuralnetworks/aidl/utils/test/PreparedModelTest.cpp b/neuralnetworks/aidl/utils/test/PreparedModelTest.cpp
new file mode 100644
index 0000000..7e28861
--- /dev/null
+++ b/neuralnetworks/aidl/utils/test/PreparedModelTest.cpp
@@ -0,0 +1,272 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MockFencedExecutionCallback.h"
+#include "MockPreparedModel.h"
+
+#include <aidl/android/hardware/neuralnetworks/IFencedExecutionCallback.h>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <nnapi/IPreparedModel.h>
+#include <nnapi/TypeUtils.h>
+#include <nnapi/Types.h>
+#include <nnapi/hal/aidl/PreparedModel.h>
+
+#include <functional>
+#include <memory>
+
+namespace aidl::android::hardware::neuralnetworks::utils {
+namespace {
+
+using ::testing::_;
+using ::testing::DoAll;
+using ::testing::Invoke;
+using ::testing::InvokeWithoutArgs;
+using ::testing::SetArgPointee;
+
+// Null AIDL prepared model; PreparedModel::create is expected to reject it.
+const std::shared_ptr<IPreparedModel> kInvalidPreparedModel;
+// Timing value with both durations set to -1.
+constexpr auto kNoTiming = Timing{.timeOnDevice = -1, .timeInDriver = -1};
+
+// Status factories used as mock return values.
+constexpr auto makeStatusOk = []() -> ndk::ScopedAStatus { return ndk::ScopedAStatus::ok(); };
+
+// Driver-reported error: a service-specific code carrying ErrorStatus::GENERAL_FAILURE.
+constexpr auto makeGeneralFailure = []() -> ndk::ScopedAStatus {
+    const auto errorCode = static_cast<int32_t>(ErrorStatus::GENERAL_FAILURE);
+    return ndk::ScopedAStatus::fromServiceSpecificError(errorCode);
+};
+// Binder-level failure that is not a dead object.
+constexpr auto makeGeneralTransportFailure = []() -> ndk::ScopedAStatus {
+    return ndk::ScopedAStatus::fromStatus(STATUS_NO_MEMORY);
+};
+// Binder-level failure reporting that the remote object died.
+constexpr auto makeDeadObjectFailure = []() -> ndk::ScopedAStatus {
+    return ndk::ScopedAStatus::fromStatus(STATUS_DEAD_OBJECT);
+};
+
+// Returns a callable usable as the action for IPreparedModel::executeFenced. It succeeds and
+// reports |callback| together with a -1 sync fence descriptor.
+auto makeFencedExecutionResult(const std::shared_ptr<MockFencedExecutionCallback>& callback) {
+    auto action = [callback](const Request& /*request*/,
+                             const std::vector<ndk::ScopedFileDescriptor>& /*waitFor*/,
+                             bool /*measureTiming*/, int64_t /*deadline*/,
+                             int64_t /*loopTimeoutDuration*/, int64_t /*duration*/,
+                             FencedExecutionResult* fencedExecutionResult) {
+        *fencedExecutionResult = FencedExecutionResult{
+                .callback = callback,
+                .syncFence = ndk::ScopedFileDescriptor(-1),
+        };
+        return ndk::ScopedAStatus::ok();
+    };
+    return action;
+}
+
+}  // namespace
+
+TEST(PreparedModelTest, invalidPreparedModel) {
+    // Wrapping a null AIDL prepared model must fail with GENERAL_FAILURE.
+    const auto createResult = PreparedModel::create(kInvalidPreparedModel);
+
+    ASSERT_FALSE(createResult.has_value());
+    EXPECT_EQ(createResult.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(PreparedModelTest, executeSync) {
+    // Arrange: the mock reports a successful synchronous execution without timing data.
+    const auto mockPreparedModel = MockPreparedModel::create();
+    const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+    const ExecutionResult mockExecutionResult{
+            .outputSufficientSize = true,
+            .outputShapes = {},
+            .timing = kNoTiming,
+    };
+    EXPECT_CALL(*mockPreparedModel, executeSynchronously(_, _, _, _, _))
+            .Times(1)
+            .WillOnce(DoAll(SetArgPointee<4>(mockExecutionResult),
+                            InvokeWithoutArgs(makeStatusOk)));
+
+    // Act.
+    const auto executeResult = preparedModel->execute({}, {}, {}, {});
+
+    // Assert: the successful AIDL call yields a value.
+    EXPECT_TRUE(executeResult.has_value()) << "Failed with " << executeResult.error().code << ": "
+                                           << executeResult.error().message;
+}
+
+TEST(PreparedModelTest, executeSyncError) {
+    // setup test
+    const auto mockPreparedModel = MockPreparedModel::create();
+    const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+    // makeGeneralFailure takes no arguments. Wrap it in InvokeWithoutArgs, as the sibling tests
+    // do, instead of Invoke, which would rely on gmock accepting a zero-argument callable.
+    EXPECT_CALL(*mockPreparedModel, executeSynchronously(_, _, _, _, _))
+            .Times(1)
+            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+    // run test
+    const auto result = preparedModel->execute({}, {}, {}, {});
+
+    // verify result: the service-specific GENERAL_FAILURE is expected to be propagated
+    ASSERT_FALSE(result.has_value());
+    EXPECT_EQ(result.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(PreparedModelTest, executeSyncTransportFailure) {
+    // A binder transport error (STATUS_NO_MEMORY) is expected to surface as GENERAL_FAILURE.
+    const auto mockPreparedModel = MockPreparedModel::create();
+    const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+    EXPECT_CALL(*mockPreparedModel, executeSynchronously(_, _, _, _, _))
+            .Times(1)
+            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+    const auto executeResult = preparedModel->execute({}, {}, {}, {});
+
+    ASSERT_FALSE(executeResult.has_value());
+    EXPECT_EQ(executeResult.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(PreparedModelTest, executeSyncDeadObject) {
+    // STATUS_DEAD_OBJECT from binder is expected to surface as nn::ErrorStatus::DEAD_OBJECT.
+    const auto mockPreparedModel = MockPreparedModel::create();
+    const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+    EXPECT_CALL(*mockPreparedModel, executeSynchronously(_, _, _, _, _))
+            .Times(1)
+            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+    const auto executeResult = preparedModel->execute({}, {}, {}, {});
+
+    ASSERT_FALSE(executeResult.has_value());
+    EXPECT_EQ(executeResult.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+TEST(PreparedModelTest, executeFenced) {
+    // setup call
+    const auto mockPreparedModel = MockPreparedModel::create();
+    const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+    const auto mockCallback = MockFencedExecutionCallback::create();
+    // makeStatusOk takes no arguments, so it is wrapped in InvokeWithoutArgs (as in the other
+    // tests) rather than Invoke.
+    EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
+            .Times(1)
+            .WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
+                            SetArgPointee<2>(ErrorStatus::NONE),
+                            InvokeWithoutArgs(makeStatusOk)));
+    EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _))
+            .Times(1)
+            .WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
+
+    // run test
+    const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+
+    // verify result: the -1 fence returned by the mock is expected to read as signaled
+    ASSERT_TRUE(result.has_value())
+            << "Failed with " << result.error().code << ": " << result.error().message;
+    const auto& [syncFence, callback] = result.value();
+    EXPECT_EQ(syncFence.syncWait({}), nn::SyncFence::FenceState::SIGNALED);
+    ASSERT_NE(callback, nullptr);
+
+    // get results from callback
+    const auto callbackResult = callback();
+    ASSERT_TRUE(callbackResult.has_value()) << "Failed with " << callbackResult.error().code << ": "
+                                            << callbackResult.error().message;
+}
+
+TEST(PreparedModelTest, executeFencedCallbackError) {
+    // setup call
+    const auto mockPreparedModel = MockPreparedModel::create();
+    const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+    const auto mockCallback = MockFencedExecutionCallback::create();
+    // DoAll already yields an action, so it is passed to WillOnce directly instead of being
+    // wrapped in Invoke. makeStatusOk takes no arguments, hence InvokeWithoutArgs.
+    EXPECT_CALL(*mockCallback, getExecutionInfo(_, _, _))
+            .Times(1)
+            .WillOnce(DoAll(SetArgPointee<0>(kNoTiming), SetArgPointee<1>(kNoTiming),
+                            SetArgPointee<2>(ErrorStatus::GENERAL_FAILURE),
+                            InvokeWithoutArgs(makeStatusOk)));
+    EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _))
+            .Times(1)
+            .WillOnce(Invoke(makeFencedExecutionResult(mockCallback)));
+
+    // run test
+    const auto result = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+
+    // verify result
+    ASSERT_TRUE(result.has_value())
+            << "Failed with " << result.error().code << ": " << result.error().message;
+    const auto& [syncFence, callback] = result.value();
+    EXPECT_NE(syncFence.syncWait({}), nn::SyncFence::FenceState::ACTIVE);
+    ASSERT_NE(callback, nullptr);
+
+    // verify callback failure: the GENERAL_FAILURE reported by getExecutionInfo is propagated
+    const auto callbackResult = callback();
+    ASSERT_FALSE(callbackResult.has_value());
+    EXPECT_EQ(callbackResult.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(PreparedModelTest, executeFencedError) {
+    // A service-specific GENERAL_FAILURE from executeFenced is expected to be propagated.
+    const auto mockPreparedModel = MockPreparedModel::create();
+    const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+    EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _))
+            .Times(1)
+            .WillOnce(InvokeWithoutArgs(makeGeneralFailure));
+
+    const auto fencedResult = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+
+    ASSERT_FALSE(fencedResult.has_value());
+    EXPECT_EQ(fencedResult.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(PreparedModelTest, executeFencedTransportFailure) {
+    // A binder transport error (STATUS_NO_MEMORY) is expected to surface as GENERAL_FAILURE.
+    const auto mockPreparedModel = MockPreparedModel::create();
+    const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+    EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _))
+            .Times(1)
+            .WillOnce(InvokeWithoutArgs(makeGeneralTransportFailure));
+
+    const auto fencedResult = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+
+    ASSERT_FALSE(fencedResult.has_value());
+    EXPECT_EQ(fencedResult.error().code, nn::ErrorStatus::GENERAL_FAILURE);
+}
+
+TEST(PreparedModelTest, executeFencedDeadObject) {
+    // STATUS_DEAD_OBJECT from binder is expected to surface as nn::ErrorStatus::DEAD_OBJECT.
+    const auto mockPreparedModel = MockPreparedModel::create();
+    const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+    EXPECT_CALL(*mockPreparedModel, executeFenced(_, _, _, _, _, _, _))
+            .Times(1)
+            .WillOnce(InvokeWithoutArgs(makeDeadObjectFailure));
+
+    const auto fencedResult = preparedModel->executeFenced({}, {}, {}, {}, {}, {});
+
+    ASSERT_FALSE(fencedResult.has_value());
+    EXPECT_EQ(fencedResult.error().code, nn::ErrorStatus::DEAD_OBJECT);
+}
+
+// TODO: test burst execution if/when it is added to nn::IPreparedModel.
+
+TEST(PreparedModelTest, getUnderlyingResource) {
+    const auto mockPreparedModel = MockPreparedModel::create();
+    const auto preparedModel = PreparedModel::create(mockPreparedModel).value();
+
+    // The wrapper is expected to expose the wrapped AIDL object as a
+    // std::shared_ptr<IPreparedModel> held in a std::any.
+    const auto resource = preparedModel->getUnderlyingResource();
+    const auto* maybeMock = std::any_cast<std::shared_ptr<IPreparedModel>>(&resource);
+
+    ASSERT_NE(maybeMock, nullptr);
+    EXPECT_EQ(maybeMock->get(), mockPreparedModel.get());
+}
+
+} // namespace aidl::android::hardware::neuralnetworks::utils
diff --git a/neuralnetworks/utils/README.md b/neuralnetworks/utils/README.md
index 45ca0b4..87b3f9f 100644
--- a/neuralnetworks/utils/README.md
+++ b/neuralnetworks/utils/README.md
@@ -49,7 +49,9 @@
(i.e., not as a nested class) or used in a subsequent version of the NN HAL. Prefer using `convert`
over `unvalidatedConvert`.
-# HIDL Interface Lifetimes across Processes
+# Interface Lifetimes across Processes
+
+## HIDL
Some notes about HIDL interface objects and lifetimes across processes:
@@ -68,7 +70,20 @@
If the process which created the HIDL interface object dies, any call on this object from another
process will result in a HIDL transport error with the code `DEAD_OBJECT`.
-# Protecting Asynchronous Calls across HIDL
+## AIDL
+
+We use the NDK backend for AIDL interfaces. Lifetimes are handled generally the same way as for
+HIDL, with the following differences:
+* Interfaces inherit from `ndk::ICInterface`, which inherits from `ndk::SharedRefBase`. The latter
+ is an analog of `::android::RefBase` using `std::shared_ptr` for reference counting.
+* AIDL calls return `ndk::ScopedAStatus`, which wraps fields of types `binder_status_t` and
+  `binder_exception_t`. If a call is made on a dead object, it returns an `ndk::ScopedAStatus`
+  with exception `EX_TRANSACTION_FAILED` and binder status `STATUS_DEAD_OBJECT`.
+
+# Protecting Asynchronous Calls
+
+## Across HIDL
Some notes about asynchronous calls across HIDL:
@@ -95,3 +110,17 @@
driver process has died, and `DeathHandler` will unblock any thread waiting on the results of an
`IProtectedCallback` callback object that may otherwise not be signaled. In order for this to work,
the `IProtectedCallback` object must have been registered via `DeathHandler::protectCallback()`.
+
+## Across AIDL
+
+We use the NDK backend for AIDL interfaces. Asynchronous calls are handled generally the same way
+as for HIDL, with the following differences:
+* AIDL calls return `ndk::ScopedAStatus`, which wraps fields of types `binder_status_t` and
+  `binder_exception_t`. If a call is made on a dead object, it returns an `ndk::ScopedAStatus`
+  with exception `EX_TRANSACTION_FAILED` and binder status `STATUS_DEAD_OBJECT`.
+* The AIDL interface does not have an asynchronous `IPreparedModel::execute`.
+* Service death is handled using an `AIBinder_DeathRecipient` object, which is linked to an
+  interface object using `AIBinder_linkToDeath`. `nnapi/hal/aidl/ProtectCallback.h` provides a
+  `DeathHandler` object that is a direct analog of the HIDL `DeathHandler`, but is implemented
+  with libbinder_ndk objects.
diff --git a/neuralnetworks/utils/common/Android.bp b/neuralnetworks/utils/common/Android.bp
index 6162fe8..2ed1e40 100644
--- a/neuralnetworks/utils/common/Android.bp
+++ b/neuralnetworks/utils/common/Android.bp
@@ -35,8 +35,10 @@
"neuralnetworks_types",
],
shared_libs: [
+ "android.hardware.neuralnetworks-V1-ndk_platform",
"libhidlbase",
"libnativewindow",
+ "libbinder_ndk",
],
}
diff --git a/neuralnetworks/utils/common/include/nnapi/hal/CommonUtils.h b/neuralnetworks/utils/common/include/nnapi/hal/CommonUtils.h
index 2f6112a..8fe6b90 100644
--- a/neuralnetworks/utils/common/include/nnapi/hal/CommonUtils.h
+++ b/neuralnetworks/utils/common/include/nnapi/hal/CommonUtils.h
@@ -32,6 +32,8 @@
// Shorthands
namespace aidl::android::hardware::neuralnetworks {
namespace aidl_hal = ::aidl::android::hardware::neuralnetworks;
+namespace hal = ::android::hardware::neuralnetworks;
+namespace nn = ::android::nn;
} // namespace aidl::android::hardware::neuralnetworks
// Shorthands
diff --git a/power/stats/aidl/OWNERS b/power/stats/aidl/OWNERS
new file mode 100644
index 0000000..b290b49
--- /dev/null
+++ b/power/stats/aidl/OWNERS
@@ -0,0 +1,3 @@
+bsschwar@google.com
+krossmo@google.com
+tstrudel@google.com
diff --git a/power/stats/aidl/android/hardware/power/stats/IPowerStats.aidl b/power/stats/aidl/android/hardware/power/stats/IPowerStats.aidl
index 7a95f74..edc43ea 100644
--- a/power/stats/aidl/android/hardware/power/stats/IPowerStats.aidl
+++ b/power/stats/aidl/android/hardware/power/stats/IPowerStats.aidl
@@ -32,7 +32,7 @@
* A PowerEntity is defined as a platform subsystem, peripheral, or power domain that impacts
* the total device power consumption.
*
- * @return List of information on each PowerEntity
+ * @return List of information on each PowerEntity for which state residency can be requested.
*/
PowerEntity[] getPowerEntityInfo();
@@ -52,11 +52,12 @@
* Passing an empty list will return state residency for all available PowerEntitys.
* ID of each PowerEntity is contained in PowerEntityInfo.
*
- * @return StateResidency since boot for each requested PowerEntity
+ * @return StateResidencyResults since boot for each requested and available PowerEntity. Note
+ * that StateResidencyResult for a given PowerEntity may not always be available. Clients shall
+ * not rely on StateResidencyResult always being returned for every request.
*
- * Returns the following service-specific exceptions in order of highest priority:
- * - STATUS_BAD_VALUE if an invalid powerEntityId is provided
- * - STATUS_FAILED_TRANSACTION if any StateResidencyResult fails to be returned
+ * Returns the following exception codes:
+ * - EX_ILLEGAL_ARGUMENT if an invalid powerEntityId is provided
*/
StateResidencyResult[] getStateResidency(in int[] powerEntityIds);
@@ -66,7 +67,7 @@
* An EnergyConsumer is a device subsystem or peripheral that consumes energy. Energy
* consumption data may be used by framework for the purpose of power attribution.
*
- * @return List of EnergyConsumers that are available.
+ * @return List of EnergyConsumers for which energy consumption can be requested.
*/
EnergyConsumer[] getEnergyConsumerInfo();
@@ -74,38 +75,40 @@
* Reports the energy consumed since boot by each requested EnergyConsumer.
*
* @param energyConsumerIds List of IDs of EnergyConsumers for which data is requested.
- * Passing an empty list will return state residency for all available EnergyConsumers.
+ * Passing an empty list will return results for all available EnergyConsumers.
*
- * @return Energy consumed since boot for each requested EnergyConsumer
+ * @return Energy consumed since boot for each requested and available EnergyConsumer. Note
+ * that EnergyConsumerResult for a given EnergyConsumer may not always be available. Clients
+ * shall not rely on EnergyConsumerResult always being returned for every request.
*
- * Returns the following service-specific exceptions in order of highest priority:
- * - STATUS_BAD_VALUE if an invalid energyConsumerId is provided
- * - STATUS_FAILED_TRANSACTION if any EnergyConsumerResult fails to be returned
+ * Returns the following exception codes:
+ * - EX_ILLEGAL_ARGUMENT if an invalid energyConsumerId is provided
*/
EnergyConsumerResult[] getEnergyConsumed(in int[] energyConsumerIds);
/**
- * Return information related to all channels monitored by Energy Meters.
+ * Return information related to all Channels monitored by Energy Meters.
*
* An Energy Meter is a device that monitors energy and may support monitoring multiple
* channels simultaneously. A channel may correspond a bus, sense resistor, or power rail.
*
- * @return Channels monitored by Energy Meters.
+ * @return All Channels for which energy measurements can be requested.
*/
Channel[] getEnergyMeterInfo();
/**
- * Reports accumulated energy for each specified channel.
+ * Reports accumulated energy for each specified Channel.
*
* @param channelIds IDs of channels for which data is requested.
* Passing an empty list will return energy measurements for all available channels.
* ID of each channel is contained in ChannelInfo.
*
- * @return Energy measured since boot for each requested channel
+ * @return Energy measured since boot for each requested and available Channel. Note
+ * that EnergyMeasurement for a given Channel may not always be available. Clients
+ * shall not rely on EnergyMeasurement always being returned for every request.
*
- * Returns the following service-specific exceptions in order of highest priority:
- * - STATUS_BAD_VALUE if an invalid channelId is provided
- * - STATUS_FAILED_TRANSACTION if any EnergyMeasurement fails to be returned
+ * Returns the following exception codes:
+ * - EX_ILLEGAL_ARGUMENT if an invalid channelId is provided
*/
EnergyMeasurement[] readEnergyMeter(in int[] channelIds);
}
diff --git a/power/stats/aidl/default/FakeEnergyConsumer.h b/power/stats/aidl/default/FakeEnergyConsumer.h
new file mode 100644
index 0000000..f41aa6e
--- /dev/null
+++ b/power/stats/aidl/default/FakeEnergyConsumer.h
@@ -0,0 +1,83 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <PowerStats.h>
+
+#include <android-base/chrono_utils.h>
+
+#include <chrono>
+#include <random>
+
+namespace aidl {
+namespace android {
+namespace hardware {
+namespace power {
+namespace stats {
+
+// Fake IEnergyConsumer whose accumulated energy grows by a pseudo-random amount on every read.
+class FakeEnergyConsumer : public PowerStats::IEnergyConsumer {
+ public:
+    // Creates a consumer of |type| named |name| with zero accumulated energy.
+    FakeEnergyConsumer(EnergyConsumerType type, std::string name)
+        : mType(type), mName(std::move(name)) {
+        mResult.timestampMs = 0;
+        mResult.energyUWs = 0;
+        mResult.attribution = {};
+    }
+
+    ~FakeEnergyConsumer() = default;
+
+    std::string getName() override { return mName; }
+
+    EnergyConsumerType getType() override { return mType; }
+
+    // Advances the fake counter and returns a snapshot of it; never returns std::nullopt.
+    std::optional<EnergyConsumerResult> getEnergyConsumed() override {
+        mFakeEnergyConsumerResult.update(&mResult);
+        return mResult;
+    }
+
+ private:
+    // Generates fake readings whose energy total only ever increases.
+    class FakeEnergyConsumerResult {
+     public:
+        FakeEnergyConsumerResult() : mDistribution(1, 100) {}
+
+        // Stamps |result| with the current time since boot and adds 100..10000 uWs to it.
+        void update(EnergyConsumerResult* result) {
+            // Draw straight from the member engine so its state advances across calls.
+            // std::bind(mDistribution, mGenerator) stored copies of both, so every call replayed
+            // the same first value.
+            const int randNum = mDistribution(mGenerator);  // in [1, 100]
+
+            // Get current time since boot in milliseconds
+            const uint64_t now = std::chrono::time_point_cast<std::chrono::milliseconds>(
+                                         ::android::base::boot_clock::now())
+                                         .time_since_epoch()
+                                         .count();
+            result->timestampMs = now;
+            result->energyUWs += randNum * 100;
+        }
+
+     private:
+        std::default_random_engine mGenerator;
+        std::uniform_int_distribution<int> mDistribution;
+    };
+
+    EnergyConsumerType mType;
+    std::string mName;
+    FakeEnergyConsumerResult mFakeEnergyConsumerResult;
+    EnergyConsumerResult mResult;
+};
+
+} // namespace stats
+} // namespace power
+} // namespace hardware
+} // namespace android
+} // namespace aidl
\ No newline at end of file
diff --git a/power/stats/aidl/default/FakeEnergyMeter.h b/power/stats/aidl/default/FakeEnergyMeter.h
new file mode 100644
index 0000000..56dcdcc
--- /dev/null
+++ b/power/stats/aidl/default/FakeEnergyMeter.h
@@ -0,0 +1,112 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <PowerStats.h>
+
+#include <android-base/chrono_utils.h>
+
+#include <chrono>
+#include <random>
+
+namespace aidl {
+namespace android {
+namespace hardware {
+namespace power {
+namespace stats {
+
+// Fake IEnergyMeter whose channels accumulate pseudo-random energy on every read.
+class FakeEnergyMeter : public PowerStats::IEnergyMeter {
+ public:
+    // Creates one Channel per (name, subsystem) pair. Ids are assigned sequentially from 0 in
+    // input order, and every channel starts with zero accumulated energy.
+    FakeEnergyMeter(const std::vector<std::pair<std::string, std::string>>& channelNames) {
+        mChannels.reserve(channelNames.size());
+        mEnergyMeasurements.reserve(channelNames.size());
+
+        int32_t channelId = 0;
+        for (const auto& [name, subsystem] : channelNames) {
+            Channel c;
+            c.id = channelId++;
+            c.name = name;
+            c.subsystem = subsystem;
+
+            EnergyMeasurement m;
+            m.id = c.id;
+            m.timestampMs = 0;
+            m.durationMs = 0;
+            m.energyUWs = 0;
+
+            mChannels.push_back(std::move(c));
+            mEnergyMeasurements.push_back(m);
+        }
+    }
+    ~FakeEnergyMeter() = default;
+
+    // Advances every channel's fake counter, then reports the requested channels, or all of
+    // them when |in_channelIds| is empty. Any out-of-range id fails the whole request with
+    // EX_ILLEGAL_ARGUMENT, before counters are advanced or |_aidl_return| is touched.
+    ndk::ScopedAStatus readEnergyMeter(const std::vector<int32_t>& in_channelIds,
+                                       std::vector<EnergyMeasurement>* _aidl_return) override {
+        for (const int32_t id : in_channelIds) {
+            // The cast is safe because negative ids are rejected first.
+            if (id < 0 || static_cast<size_t>(id) >= mEnergyMeasurements.size()) {
+                return ndk::ScopedAStatus(AStatus_fromExceptionCode(EX_ILLEGAL_ARGUMENT));
+            }
+        }
+
+        for (auto& measurement : mEnergyMeasurements) {
+            mFakeEnergyMeasurement.update(&measurement);
+        }
+
+        if (in_channelIds.empty()) {
+            *_aidl_return = mEnergyMeasurements;
+        } else {
+            for (const int32_t id : in_channelIds) {
+                _aidl_return->push_back(mEnergyMeasurements[id]);
+            }
+        }
+
+        return ndk::ScopedAStatus::ok();
+    }
+
+    ndk::ScopedAStatus getEnergyMeterInfo(std::vector<Channel>* _aidl_return) override {
+        *_aidl_return = mChannels;
+        return ndk::ScopedAStatus::ok();
+    }
+
+ private:
+    // Generates fake measurements whose energy total only ever increases.
+    class FakeEnergyMeasurement {
+     public:
+        FakeEnergyMeasurement() : mDistribution(1, 100) {}
+
+        void update(EnergyMeasurement* measurement) {
+            // Draw straight from the member engine so its state advances across calls.
+            // std::bind(mDistribution, mGenerator) stored copies of both, so every call replayed
+            // the same first value.
+            const int randNum = mDistribution(mGenerator);  // in [1, 100]
+
+            // Get current time since boot in milliseconds
+            const uint64_t now = std::chrono::time_point_cast<std::chrono::milliseconds>(
+                                         ::android::base::boot_clock::now())
+                                         .time_since_epoch()
+                                         .count();
+            measurement->timestampMs = now;
+            // Energy is reported as accumulated since boot, so the window is the full uptime.
+            measurement->durationMs = now;
+            measurement->energyUWs += randNum * 100;
+        }
+
+     private:
+        std::default_random_engine mGenerator;
+        std::uniform_int_distribution<int> mDistribution;
+    };
+
+    std::vector<Channel> mChannels;
+    FakeEnergyMeasurement mFakeEnergyMeasurement;
+    std::vector<EnergyMeasurement> mEnergyMeasurements;
+};
+
+} // namespace stats
+} // namespace power
+} // namespace hardware
+} // namespace android
+} // namespace aidl
\ No newline at end of file
diff --git a/power/stats/aidl/default/FakeStateResidencyDataProvider.h b/power/stats/aidl/default/FakeStateResidencyDataProvider.h
new file mode 100644
index 0000000..2eeab61
--- /dev/null
+++ b/power/stats/aidl/default/FakeStateResidencyDataProvider.h
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <PowerStats.h>
+
+#include <random>
+
+namespace aidl {
+namespace android {
+namespace hardware {
+namespace power {
+namespace stats {
+
+// Fake IStateResidencyDataProvider serving one PowerEntity whose residency counters grow
+// pseudo-randomly on every read.
+class FakeStateResidencyDataProvider : public PowerStats::IStateResidencyDataProvider {
+ public:
+    // Serves the PowerEntity |name| with |states|; every state starts with zero residency.
+    FakeStateResidencyDataProvider(const std::string& name, std::vector<State> states)
+        : mName(name), mStates(std::move(states)) {
+        mResidencies.reserve(mStates.size());
+        for (const auto& state : mStates) {
+            StateResidency r;
+            r.id = state.id;
+            r.totalTimeInStateMs = 0;
+            r.totalStateEntryCount = 0;
+            r.lastEntryTimestampMs = 0;
+            mResidencies.push_back(r);
+        }
+    }
+    ~FakeStateResidencyDataProvider() = default;
+
+    // Methods from PowerStats::IStateResidencyDataProvider
+
+    // Advances every fake counter, then reports them under |mName|. An entry already present
+    // in |residencies| is left untouched (std::unordered_map::emplace does not overwrite).
+    bool getStateResidencies(
+            std::unordered_map<std::string, std::vector<StateResidency>>* residencies) override {
+        for (auto& residency : mResidencies) {
+            mFakeStateResidency.update(&residency);
+        }
+
+        residencies->emplace(mName, mResidencies);
+        return true;
+    }
+
+    std::unordered_map<std::string, std::vector<State>> getInfo() override {
+        return {{mName, mStates}};
+    }
+
+ private:
+    // Generates fake residency counters that only ever increase.
+    class FakeStateResidency {
+     public:
+        FakeStateResidency() : mDistribution(1, 100) {}
+
+        void update(StateResidency* residency) {
+            // Draw straight from the member engine (values in [1, 100]) so its state persists
+            // across calls. std::bind stored copies of the engine, so every update replayed the
+            // same three values.
+            residency->totalTimeInStateMs += mDistribution(mGenerator) * 100;
+            residency->totalStateEntryCount += mDistribution(mGenerator);
+            residency->lastEntryTimestampMs += mDistribution(mGenerator) * 100;
+        }
+
+     private:
+        std::default_random_engine mGenerator;
+        std::uniform_int_distribution<int> mDistribution;
+    };
+
+    const std::string mName;
+    const std::vector<State> mStates;
+    FakeStateResidency mFakeStateResidency;
+    std::vector<StateResidency> mResidencies;
+};
+
+} // namespace stats
+} // namespace power
+} // namespace hardware
+} // namespace android
+} // namespace aidl
\ No newline at end of file
diff --git a/power/stats/aidl/default/PowerStats.cpp b/power/stats/aidl/default/PowerStats.cpp
index 0ffbd08..7cf591e 100644
--- a/power/stats/aidl/default/PowerStats.cpp
+++ b/power/stats/aidl/default/PowerStats.cpp
@@ -18,46 +18,153 @@
#include <android-base/logging.h>
+#include <numeric>
+
namespace aidl {
namespace android {
namespace hardware {
namespace power {
namespace stats {
+namespace {
+
+// Non-owning IStateResidencyDataProvider that forwards every call to a provider owned by an
+// earlier slot of PowerStats' provider list. It lets one provider that reports several
+// PowerEntitys occupy one slot per PowerEntity id while being owned exactly once.
+class StateResidencyDataProviderRef : public PowerStats::IStateResidencyDataProvider {
+ public:
+    explicit StateResidencyDataProviderRef(PowerStats::IStateResidencyDataProvider* target)
+        : mTarget(target) {}
+
+    bool getStateResidencies(
+            std::unordered_map<std::string, std::vector<StateResidency>>* residencies) override {
+        return mTarget->getStateResidencies(residencies);
+    }
+
+    std::unordered_map<std::string, std::vector<State>> getInfo() override {
+        return mTarget->getInfo();
+    }
+
+ private:
+    PowerStats::IStateResidencyDataProvider* const mTarget;
+};
+
+}  // namespace
+
+// Registers |p| and assigns one PowerEntity id per entity reported by p->getInfo(). Slot |id| of
+// mStateResidencyDataProviders serves PowerEntity |id|. The slot for the first entity owns |p|;
+// slots for further entities hold non-owning forwarders to it. A null |p| is ignored.
+void PowerStats::addStateResidencyDataProvider(std::unique_ptr<IStateResidencyDataProvider> p) {
+    if (!p) {
+        return;
+    }
+
+    // Stays valid after |p| is moved: the object is then owned by mStateResidencyDataProviders.
+    IStateResidencyDataProvider* const provider = p.get();
+    int32_t id = mPowerEntityInfos.size();
+
+    for (const auto& [entityName, states] : provider->getInfo()) {
+        PowerEntity i = {
+                .id = id++,
+                .name = entityName,
+                .states = states,
+        };
+        mPowerEntityInfos.emplace_back(i);
+        if (p) {
+            // First entity reported by this provider: this slot takes ownership.
+            mStateResidencyDataProviders.emplace_back(std::move(p));
+        } else {
+            // |p| was already moved into an earlier slot. Moving it again would store a null
+            // pointer that getStateResidency() later dereferences.
+            mStateResidencyDataProviders.emplace_back(
+                    std::make_unique<StateResidencyDataProviderRef>(provider));
+        }
+    }
+}
+
+// Registers |p| as a new EnergyConsumer. Its id is its index in mEnergyConsumers, and its ordinal
+// is the number of previously registered consumers of the same type. A null |p| is ignored.
+void PowerStats::addEnergyConsumer(std::unique_ptr<IEnergyConsumer> p) {
+    if (!p) {
+        return;
+    }
+
+    const EnergyConsumerType type = p->getType();
+    // Call std::count_if explicitly instead of relying on ADL through the iterator type.
+    const int32_t count =
+            std::count_if(mEnergyConsumerInfos.begin(), mEnergyConsumerInfos.end(),
+                          [type](const EnergyConsumer& c) { return c.type == type; });
+    const int32_t id = mEnergyConsumers.size();
+    mEnergyConsumerInfos.emplace_back(
+            EnergyConsumer{.id = id, .ordinal = count, .type = type, .name = p->getName()});
+    mEnergyConsumers.emplace_back(std::move(p));
+}
+
+// Installs |p| as the energy meter, destroying any meter installed earlier. With a null meter,
+// getEnergyMeterInfo() and readEnergyMeter() return OK without producing data.
+void PowerStats::setEnergyMeter(std::unique_ptr<IEnergyMeter> p) {
+    mEnergyMeter.reset(p.release());
+}
+
ndk::ScopedAStatus PowerStats::getPowerEntityInfo(std::vector<PowerEntity>* _aidl_return) {
- (void)_aidl_return;
+ *_aidl_return = mPowerEntityInfos;
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus PowerStats::getStateResidency(const std::vector<int32_t>& in_powerEntityIds,
std::vector<StateResidencyResult>* _aidl_return) {
- (void)in_powerEntityIds;
- (void)_aidl_return;
+    // An empty request means "all registered PowerEntitys".
+    if (in_powerEntityIds.empty()) {
+        if (mPowerEntityInfos.empty()) {
+            // Nothing registered. Returning here also prevents endless recursion on an empty list.
+            return ndk::ScopedAStatus::ok();
+        }
+        std::vector<int32_t> v(mPowerEntityInfos.size());
+        std::iota(std::begin(v), std::end(v), 0);
+        return getStateResidency(v, _aidl_return);
+    }
+
+    // Validate every id first, so an invalid request yields EX_ILLEGAL_ARGUMENT (as documented
+    // in IPowerStats) without leaving partial results in |_aidl_return|.
+    for (const int32_t id : in_powerEntityIds) {
+        if (id < 0 || static_cast<size_t>(id) >= mPowerEntityInfos.size()) {
+            return ndk::ScopedAStatus(AStatus_fromExceptionCode(EX_ILLEGAL_ARGUMENT));
+        }
+    }
+
+    // Residencies keyed by PowerEntity name. One provider call may report several entities, so a
+    // provider is only queried when the requested entity is not already present.
+    std::unordered_map<std::string, std::vector<StateResidency>> stateResidencies;
+
+    for (const int32_t id : in_powerEntityIds) {
+        const std::string& powerEntityName = mPowerEntityInfos[id].name;
+
+        if (stateResidencies.find(powerEntityName) == stateResidencies.end()) {
+            const auto& provider = mStateResidencyDataProviders[id];
+            if (provider) {
+                provider->getStateResidencies(&stateResidencies);
+            } else {
+                LOG(ERROR) << "No StateResidencyDataProvider for " << powerEntityName;
+            }
+        }
+
+        // Append results if we have them
+        auto stateResidency = stateResidencies.find(powerEntityName);
+        if (stateResidency != stateResidencies.end()) {
+            StateResidencyResult res = {
+                    .id = id,
+                    .stateResidencyData = stateResidency->second,
+            };
+            _aidl_return->emplace_back(std::move(res));
+        } else {
+            // Failed to get results for the given id.
+            LOG(ERROR) << "Failed to get results for " << powerEntityName;
+        }
+    }
+
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus PowerStats::getEnergyConsumerInfo(std::vector<EnergyConsumer>* _aidl_return) {
- (void)_aidl_return;
+ *_aidl_return = mEnergyConsumerInfos;
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus PowerStats::getEnergyConsumed(const std::vector<int32_t>& in_energyConsumerIds,
std::vector<EnergyConsumerResult>* _aidl_return) {
- (void)in_energyConsumerIds;
- (void)_aidl_return;
+    // An empty request means "all registered EnergyConsumers".
+    if (in_energyConsumerIds.empty()) {
+        if (mEnergyConsumers.empty()) {
+            // Nothing registered. Returning here also prevents endless recursion on an empty list.
+            return ndk::ScopedAStatus::ok();
+        }
+        std::vector<int32_t> v(mEnergyConsumers.size());
+        std::iota(std::begin(v), std::end(v), 0);
+        return getEnergyConsumed(v, _aidl_return);
+    }
+
+    // Validate every id first, so an invalid request yields EX_ILLEGAL_ARGUMENT (as documented
+    // in IPowerStats) without leaving partial results in |_aidl_return|.
+    for (const int32_t id : in_energyConsumerIds) {
+        if (id < 0 || static_cast<size_t>(id) >= mEnergyConsumers.size()) {
+            return ndk::ScopedAStatus(AStatus_fromExceptionCode(EX_ILLEGAL_ARGUMENT));
+        }
+    }
+
+    for (const int32_t id : in_energyConsumerIds) {
+        auto optionalResult = mEnergyConsumers[id]->getEnergyConsumed();
+        if (optionalResult) {
+            EnergyConsumerResult& result = *optionalResult;
+            // Ids are assigned by addEnergyConsumer(), so overwrite whatever the consumer set.
+            result.id = id;
+            _aidl_return->emplace_back(std::move(result));
+        } else {
+            // Failed to get results for the given id.
+            LOG(ERROR) << "Failed to get results for " << mEnergyConsumerInfos[id].name;
+        }
+    }
+
return ndk::ScopedAStatus::ok();
}
ndk::ScopedAStatus PowerStats::getEnergyMeterInfo(std::vector<Channel>* _aidl_return) {
- (void)_aidl_return;
- return ndk::ScopedAStatus::ok();
+ if (!mEnergyMeter) {
+ return ndk::ScopedAStatus::ok();
+ }
+
+ return mEnergyMeter->getEnergyMeterInfo(_aidl_return);
}
ndk::ScopedAStatus PowerStats::readEnergyMeter(const std::vector<int32_t>& in_channelIds,
std::vector<EnergyMeasurement>* _aidl_return) {
- (void)in_channelIds;
- (void)_aidl_return;
- return ndk::ScopedAStatus::ok();
+ if (!mEnergyMeter) {
+ return ndk::ScopedAStatus::ok();
+ }
+
+ return mEnergyMeter->readEnergyMeter(in_channelIds, _aidl_return);
}
} // namespace stats
diff --git a/power/stats/aidl/default/PowerStats.h b/power/stats/aidl/default/PowerStats.h
index cb98e55..f4c5e69 100644
--- a/power/stats/aidl/default/PowerStats.h
+++ b/power/stats/aidl/default/PowerStats.h
@@ -18,6 +18,8 @@
#include <aidl/android/hardware/power/stats/BnPowerStats.h>
+#include <unordered_map>
+
namespace aidl {
namespace android {
namespace hardware {
@@ -26,7 +28,37 @@
class PowerStats : public BnPowerStats {
public:
+ class IStateResidencyDataProvider {
+ public:
+ virtual ~IStateResidencyDataProvider() = default;
+ virtual bool getStateResidencies(
+ std::unordered_map<std::string, std::vector<StateResidency>>* residencies) = 0;
+ virtual std::unordered_map<std::string, std::vector<State>> getInfo() = 0;
+ };
+
+ class IEnergyConsumer {
+ public:
+ virtual ~IEnergyConsumer() = default;
+ virtual std::string getName() = 0;
+ virtual EnergyConsumerType getType() = 0;
+ virtual std::optional<EnergyConsumerResult> getEnergyConsumed() = 0;
+ };
+
+ class IEnergyMeter {
+ public:
+ virtual ~IEnergyMeter() = default;
+ virtual ndk::ScopedAStatus readEnergyMeter(
+ const std::vector<int32_t>& in_channelIds,
+ std::vector<EnergyMeasurement>* _aidl_return) = 0;
+ virtual ndk::ScopedAStatus getEnergyMeterInfo(std::vector<Channel>* _aidl_return) = 0;
+ };
+
PowerStats() = default;
+
+ void addStateResidencyDataProvider(std::unique_ptr<IStateResidencyDataProvider> p);
+ void addEnergyConsumer(std::unique_ptr<IEnergyConsumer> p);
+ void setEnergyMeter(std::unique_ptr<IEnergyMeter> p);
+
// Methods from aidl::android::hardware::power::stats::IPowerStats
ndk::ScopedAStatus getPowerEntityInfo(std::vector<PowerEntity>* _aidl_return) override;
ndk::ScopedAStatus getStateResidency(const std::vector<int32_t>& in_powerEntityIds,
@@ -37,6 +69,15 @@
ndk::ScopedAStatus getEnergyMeterInfo(std::vector<Channel>* _aidl_return) override;
ndk::ScopedAStatus readEnergyMeter(const std::vector<int32_t>& in_channelIds,
std::vector<EnergyMeasurement>* _aidl_return) override;
+
+ private:
+ std::vector<std::unique_ptr<IStateResidencyDataProvider>> mStateResidencyDataProviders;
+ std::vector<PowerEntity> mPowerEntityInfos;
+
+ std::vector<std::unique_ptr<IEnergyConsumer>> mEnergyConsumers;
+ std::vector<EnergyConsumer> mEnergyConsumerInfos;
+
+ std::unique_ptr<IEnergyMeter> mEnergyMeter;
};
} // namespace stats
diff --git a/power/stats/aidl/default/main.cpp b/power/stats/aidl/default/main.cpp
index 0469b4c..2fe3d2e 100644
--- a/power/stats/aidl/default/main.cpp
+++ b/power/stats/aidl/default/main.cpp
@@ -16,16 +16,61 @@
#include "PowerStats.h"
+#include "FakeEnergyConsumer.h"
+#include "FakeEnergyMeter.h"
+#include "FakeStateResidencyDataProvider.h"
+
#include <android-base/logging.h>
#include <android/binder_manager.h>
#include <android/binder_process.h>
+using aidl::android::hardware::power::stats::EnergyConsumerType;
+using aidl::android::hardware::power::stats::FakeEnergyConsumer;
+using aidl::android::hardware::power::stats::FakeEnergyMeter;
+using aidl::android::hardware::power::stats::FakeStateResidencyDataProvider;
using aidl::android::hardware::power::stats::PowerStats;
+using aidl::android::hardware::power::stats::State;
+
+void setFakeEnergyMeter(std::shared_ptr<PowerStats> p) {
+ p->setEnergyMeter(
+ std::make_unique<FakeEnergyMeter>(std::vector<std::pair<std::string, std::string>>{
+ {"Rail1", "Display"},
+ {"Rail2", "CPU"},
+ {"Rail3", "Modem"},
+ }));
+}
+
+void addFakeStateResidencyDataProvider1(std::shared_ptr<PowerStats> p) {
+ p->addStateResidencyDataProvider(std::make_unique<FakeStateResidencyDataProvider>(
+ "CPU", std::vector<State>{{0, "Idle"}, {1, "Active"}}));
+}
+
+void addFakeStateResidencyDataProvider2(std::shared_ptr<PowerStats> p) {
+ p->addStateResidencyDataProvider(std::make_unique<FakeStateResidencyDataProvider>(
+ "Display", std::vector<State>{{0, "Off"}, {1, "On"}}));
+}
+
+void addFakeEnergyConsumer1(std::shared_ptr<PowerStats> p) {
+ p->addEnergyConsumer(std::make_unique<FakeEnergyConsumer>(EnergyConsumerType::OTHER, "GPU"));
+}
+
+void addFakeEnergyConsumer2(std::shared_ptr<PowerStats> p) {
+ p->addEnergyConsumer(
+ std::make_unique<FakeEnergyConsumer>(EnergyConsumerType::MOBILE_RADIO, "MODEM"));
+}
int main() {
ABinderProcess_setThreadPoolMaxThreadCount(0);
std::shared_ptr<PowerStats> p = ndk::SharedRefBase::make<PowerStats>();
+ setFakeEnergyMeter(p);
+
+ addFakeStateResidencyDataProvider1(p);
+ addFakeStateResidencyDataProvider2(p);
+
+ addFakeEnergyConsumer1(p);
+ addFakeEnergyConsumer2(p);
+
const std::string instance = std::string() + PowerStats::descriptor + "/default";
binder_status_t status = AServiceManager_addService(p->asBinder().get(), instance.c_str());
CHECK(status == STATUS_OK);
diff --git a/radio/1.0/vts/functional/Android.bp b/radio/1.0/vts/functional/Android.bp
index 9e92d93..2c0e70a 100644
--- a/radio/1.0/vts/functional/Android.bp
+++ b/radio/1.0/vts/functional/Android.bp
@@ -43,6 +43,8 @@
],
static_libs: [
"android.hardware.radio@1.0",
+ "android.hardware.radio@1.1",
+ "android.hardware.radio@1.2",
],
test_config: "vts_hal_radio_target_test.xml",
test_suites: [
diff --git a/radio/1.0/vts/functional/radio_hidl_hal_data.cpp b/radio/1.0/vts/functional/radio_hidl_hal_data.cpp
index e3ee9d4..655b869 100644
--- a/radio/1.0/vts/functional/radio_hidl_hal_data.cpp
+++ b/radio/1.0/vts/functional/radio_hidl_hal_data.cpp
@@ -15,6 +15,7 @@
*/
#include <android-base/logging.h>
+#include <android/hardware/radio/1.2/IRadio.h>
#include <radio_hidl_hal_utils_v1_0.h>
using namespace ::android::hardware::radio::V1_0;
@@ -139,6 +140,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp->rspInfo.type);
EXPECT_EQ(serial, radioRsp->rspInfo.serial);
+ // setupDataCall is deprecated on radio::V1_2 with setupDataCall_1_2
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_2);
+
if (cardStatus.cardState == CardState::ABSENT) {
ASSERT_TRUE(CheckAnyOfErrors(radioRsp->rspInfo.error,
{RadioError::NONE, RadioError::OP_NOT_ALLOWED_BEFORE_REG_TO_NW,
@@ -164,6 +168,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp->rspInfo.type);
EXPECT_EQ(serial, radioRsp->rspInfo.serial);
+ // deactivateDataCall is deprecated on radio::V1_2 with deactivateDataCall_1_2
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_2);
+
if (cardStatus.cardState == CardState::ABSENT) {
ASSERT_TRUE(CheckAnyOfErrors(radioRsp->rspInfo.error,
{RadioError::NONE, RadioError::RADIO_NOT_AVAILABLE,
diff --git a/radio/1.0/vts/functional/radio_hidl_hal_misc.cpp b/radio/1.0/vts/functional/radio_hidl_hal_misc.cpp
index 3f96473..624d003 100644
--- a/radio/1.0/vts/functional/radio_hidl_hal_misc.cpp
+++ b/radio/1.0/vts/functional/radio_hidl_hal_misc.cpp
@@ -15,6 +15,7 @@
*/
#include <android-base/logging.h>
+#include <android/hardware/radio/1.2/IRadio.h>
#include <radio_hidl_hal_utils_v1_0.h>
/*
@@ -771,6 +772,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp->rspInfo.type);
EXPECT_EQ(serial, radioRsp->rspInfo.serial);
+ // HAL 1.2 and later use the always-on LCE that relies on indications.
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_2);
+
if (cardStatus.cardState == CardState::ABSENT) {
ASSERT_TRUE(CheckAnyOfErrors(
radioRsp->rspInfo.error,
@@ -792,6 +796,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp->rspInfo.type);
EXPECT_EQ(serial, radioRsp->rspInfo.serial);
+ // HAL 1.2 and later use the always-on LCE that relies on indications.
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_2);
+
if (cardStatus.cardState == CardState::ABSENT) {
ASSERT_TRUE(CheckAnyOfErrors(radioRsp->rspInfo.error,
{RadioError::NONE, RadioError::LCE_NOT_SUPPORTED,
@@ -812,6 +819,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp->rspInfo.type);
EXPECT_EQ(serial, radioRsp->rspInfo.serial);
+ // HAL 1.2 and later use the always-on LCE that relies on indications.
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_2);
+
if (cardStatus.cardState == CardState::ABSENT) {
ASSERT_TRUE(CheckAnyOfErrors(radioRsp->rspInfo.error,
{RadioError::NONE, RadioError::INTERNAL_ERR,
@@ -971,6 +981,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp->rspInfo.type);
EXPECT_EQ(serial, radioRsp->rspInfo.serial);
+ // setIndicationFilter is deprecated on radio::V1_2 with setIndicationFilter_1_2
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_2);
+
std::cout << static_cast<int>(radioRsp->rspInfo.error) << std::endl;
if (cardStatus.cardState == CardState::ABSENT) {
@@ -992,6 +1005,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp->rspInfo.type);
EXPECT_EQ(serial, radioRsp->rspInfo.serial);
+ // setSimCardPower is deprecated on radio::V1_1 with setSimCardPower_1_1
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_1);
+
if (cardStatus.cardState == CardState::ABSENT) {
ASSERT_TRUE(CheckAnyOfErrors(radioRsp->rspInfo.error,
{RadioError::NONE, RadioError::REQUEST_NOT_SUPPORTED}));
diff --git a/radio/1.0/vts/functional/radio_hidl_hal_utils_v1_0.h b/radio/1.0/vts/functional/radio_hidl_hal_utils_v1_0.h
index 8a551f7..e3e9473 100644
--- a/radio/1.0/vts/functional/radio_hidl_hal_utils_v1_0.h
+++ b/radio/1.0/vts/functional/radio_hidl_hal_utils_v1_0.h
@@ -38,6 +38,8 @@
#define TIMEOUT_PERIOD 75
#define RADIO_SERVICE_NAME "slot1"
+#define SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(__ver__) \
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL(__ver__, radio, radioRsp)
class RadioHidlTest;
extern CardStatus cardStatus;
diff --git a/radio/1.0/vts/functional/vts_test_util.cpp b/radio/1.0/vts/functional/vts_test_util.cpp
index 9a2d089..fc37201 100644
--- a/radio/1.0/vts/functional/vts_test_util.cpp
+++ b/radio/1.0/vts/functional/vts_test_util.cpp
@@ -19,6 +19,8 @@
#include <iostream>
#include "VtsCoreUtil.h"
+#define WAIT_TIMEOUT_PERIOD 75
+
int GetRandomSerialNumber() {
return rand();
}
@@ -99,4 +101,33 @@
::android::hardware::radio::V1_0::RegState::NOT_REG_MT_SEARCHING_OP_EM == state ||
::android::hardware::radio::V1_0::RegState::REG_DENIED_EM == state ||
::android::hardware::radio::V1_0::RegState::UNKNOWN_EM == state;
-}
\ No newline at end of file
+}
+
+/*
+ * Notify that the response message is received.
+ */
+void RadioResponseWaiter::notify(int receivedSerial) {
+ std::unique_lock<std::mutex> lock(mtx_);
+ if (serial == receivedSerial) {
+ count_++;
+ cv_.notify_one();
+ }
+}
+
+/*
+ * Wait till the response message is notified or till WAIT_TIMEOUT_PERIOD seconds elapse (monotonic clock).
+ */
+std::cv_status RadioResponseWaiter::wait() {
+ std::unique_lock<std::mutex> lock(mtx_);
+
+ std::cv_status status = std::cv_status::no_timeout;
+ auto now = std::chrono::steady_clock::now();
+ while (count_ == 0) {
+ status = cv_.wait_until(lock, now + std::chrono::seconds(WAIT_TIMEOUT_PERIOD));
+ if (status == std::cv_status::timeout) {
+ return status;
+ }
+ }
+ count_--;
+ return status;
+}
diff --git a/radio/1.0/vts/functional/vts_test_util.h b/radio/1.0/vts/functional/vts_test_util.h
index 218e823..eeb1d29 100644
--- a/radio/1.0/vts/functional/vts_test_util.h
+++ b/radio/1.0/vts/functional/vts_test_util.h
@@ -14,6 +14,8 @@
* limitations under the License.
*/
+#pragma once
+
#include <android-base/logging.h>
#include <android/hardware/radio/1.0/types.h>
@@ -25,6 +27,20 @@
using ::android::hardware::radio::V1_0::SapResultCode;
using namespace std;
+/*
+ * MACRO used to skip a test case when the radio response returns error REQUEST_NOT_SUPPORTED
+ * on HAL versions which have deprecated the request interfaces. The MACRO can only be used
+ * AFTER receiving the radio response.
+ */
+#define SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL(__ver__, __radio__, __radioRsp__) \
+ do { \
+ sp<::android::hardware::radio::V##__ver__::IRadio> __radio = \
+ ::android::hardware::radio::V##__ver__::IRadio::castFrom(__radio__); \
+ if (__radio && __radioRsp__->rspInfo.error == RadioError::REQUEST_NOT_SUPPORTED) { \
+ GTEST_SKIP() << "REQUEST_NOT_SUPPORTED"; \
+ } \
+ } while (0)
+
enum CheckFlag {
CHECK_DEFAULT = 0,
CHECK_GENERAL_ERROR = 1,
@@ -81,4 +97,24 @@
/*
* Check if voice status is in service.
*/
-bool isVoiceInService(RegState state);
\ No newline at end of file
+bool isVoiceInService(RegState state);
+
+/**
+ * Used when waiting for an asynchronous response from the HAL.
+ */
+class RadioResponseWaiter {
+ protected:
+ std::mutex mtx_;
+ std::condition_variable cv_;
+ int count_ = 0;  // responses notified via notify() but not yet consumed by wait()
+
+ public:
+ /* Serial number for radio request */
+ int serial = 0;
+
+ /* Used as a mechanism to inform the test about data/event callback */
+ void notify(int receivedSerial);
+
+ /* Test code calls this function to wait for response */
+ std::cv_status wait();
+};
diff --git a/radio/1.1/vts/functional/Android.bp b/radio/1.1/vts/functional/Android.bp
index 3ada6ff..b3def8e 100644
--- a/radio/1.1/vts/functional/Android.bp
+++ b/radio/1.1/vts/functional/Android.bp
@@ -35,6 +35,7 @@
],
static_libs: [
"RadioVtsTestUtilBase",
+ "android.hardware.radio@1.2",
"android.hardware.radio@1.1",
"android.hardware.radio@1.0",
],
diff --git a/radio/1.1/vts/functional/radio_hidl_hal_api.cpp b/radio/1.1/vts/functional/radio_hidl_hal_api.cpp
index 08121fd..389944b 100644
--- a/radio/1.1/vts/functional/radio_hidl_hal_api.cpp
+++ b/radio/1.1/vts/functional/radio_hidl_hal_api.cpp
@@ -14,6 +14,7 @@
* limitations under the License.
*/
+#include <android/hardware/radio/1.2/IRadio.h>
#include <radio_hidl_hal_utils_v1_1.h>
#include <vector>
@@ -107,6 +108,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp_v1_1->rspInfo.type);
EXPECT_EQ(serial, radioRsp_v1_1->rspInfo.serial);
+ // startNetworkScan is deprecated on radio::V1_2 with startNetworkScan_1_2
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_2);
+
if (cardStatus.cardState == CardState::ABSENT) {
ALOGI("startNetworkScan, rspInfo.error = %d\n", (int32_t)radioRsp_v1_1->rspInfo.error);
ASSERT_TRUE(CheckAnyOfErrors(
@@ -131,6 +135,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp_v1_1->rspInfo.type);
EXPECT_EQ(serial, radioRsp_v1_1->rspInfo.serial);
+ // startNetworkScan is deprecated on radio::V1_2 with startNetworkScan_1_2
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_2);
+
if (cardStatus.cardState == CardState::ABSENT) {
ALOGI("startNetworkScan_InvalidArgument, rspInfo.error = %d\n",
(int32_t)radioRsp_v1_1->rspInfo.error);
diff --git a/radio/1.1/vts/functional/radio_hidl_hal_utils_v1_1.h b/radio/1.1/vts/functional/radio_hidl_hal_utils_v1_1.h
index b81ee13..bafde77 100644
--- a/radio/1.1/vts/functional/radio_hidl_hal_utils_v1_1.h
+++ b/radio/1.1/vts/functional/radio_hidl_hal_utils_v1_1.h
@@ -40,6 +40,8 @@
#define TIMEOUT_PERIOD 75
#define RADIO_SERVICE_NAME "slot1"
+#define SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(__ver__) \
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL(__ver__, radio_v1_1, radioRsp_v1_1)
class RadioHidlTest_v1_1;
extern CardStatus cardStatus;
diff --git a/radio/1.2/vts/functional/Android.bp b/radio/1.2/vts/functional/Android.bp
index 1447ade..a62000f 100644
--- a/radio/1.2/vts/functional/Android.bp
+++ b/radio/1.2/vts/functional/Android.bp
@@ -36,6 +36,8 @@
],
static_libs: [
"RadioVtsTestUtilBase",
+ "android.hardware.radio@1.4",
+ "android.hardware.radio@1.3",
"android.hardware.radio@1.2",
"android.hardware.radio@1.1",
"android.hardware.radio@1.0",
diff --git a/radio/1.2/vts/functional/radio_hidl_hal_api.cpp b/radio/1.2/vts/functional/radio_hidl_hal_api.cpp
index acb1b0e..2400bde 100644
--- a/radio/1.2/vts/functional/radio_hidl_hal_api.cpp
+++ b/radio/1.2/vts/functional/radio_hidl_hal_api.cpp
@@ -14,6 +14,7 @@
* limitations under the License.
*/
+#include <android/hardware/radio/1.4/IRadio.h>
#include <radio_hidl_hal_utils_v1_2.h>
#include <vector>
@@ -57,6 +58,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp_v1_2->rspInfo.type);
EXPECT_EQ(serial, radioRsp_v1_2->rspInfo.serial);
+ // startNetworkScan_1_2 is deprecated in radio::V1_4 with startNetworkScan_1_4
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_4);
+
ALOGI("startNetworkScan, rspInfo.error = %s\n", toString(radioRsp_v1_2->rspInfo.error).c_str());
if (cardStatus.base.cardState == CardState::ABSENT) {
ASSERT_TRUE(CheckAnyOfErrors(radioRsp_v1_2->rspInfo.error, {RadioError::SIM_ABSENT}));
@@ -94,6 +98,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp_v1_2->rspInfo.type);
EXPECT_EQ(serial, radioRsp_v1_2->rspInfo.serial);
+ // startNetworkScan_1_2 is deprecated in radio::V1_4 with startNetworkScan_1_4
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_4);
+
ALOGI("startNetworkScan_InvalidArgument, rspInfo.error = %s\n",
toString(radioRsp_v1_2->rspInfo.error).c_str());
if (cardStatus.base.cardState == CardState::ABSENT) {
@@ -126,6 +133,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp_v1_2->rspInfo.type);
EXPECT_EQ(serial, radioRsp_v1_2->rspInfo.serial);
+ // startNetworkScan_1_2 is deprecated in radio::V1_4 with startNetworkScan_1_4
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_4);
+
ALOGI("startNetworkScan_InvalidInterval1, rspInfo.error = %s\n",
toString(radioRsp_v1_2->rspInfo.error).c_str());
if (cardStatus.base.cardState == CardState::ABSENT) {
@@ -158,6 +168,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp_v1_2->rspInfo.type);
EXPECT_EQ(serial, radioRsp_v1_2->rspInfo.serial);
+ // startNetworkScan_1_2 is deprecated in radio::V1_4 with startNetworkScan_1_4
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_4);
+
ALOGI("startNetworkScan_InvalidInterval2, rspInfo.error = %s\n",
toString(radioRsp_v1_2->rspInfo.error).c_str());
if (cardStatus.base.cardState == CardState::ABSENT) {
@@ -190,6 +203,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp_v1_2->rspInfo.type);
EXPECT_EQ(serial, radioRsp_v1_2->rspInfo.serial);
+ // startNetworkScan_1_2 is deprecated in radio::V1_4 with startNetworkScan_1_4
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_4);
+
ALOGI("startNetworkScan_InvalidMaxSearchTime1, rspInfo.error = %s\n",
toString(radioRsp_v1_2->rspInfo.error).c_str());
if (cardStatus.base.cardState == CardState::ABSENT) {
@@ -222,6 +238,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp_v1_2->rspInfo.type);
EXPECT_EQ(serial, radioRsp_v1_2->rspInfo.serial);
+ // startNetworkScan_1_2 is deprecated in radio::V1_4 with startNetworkScan_1_4
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_4);
+
ALOGI("startNetworkScan_InvalidMaxSearchTime2, rspInfo.error = %s\n",
toString(radioRsp_v1_2->rspInfo.error).c_str());
if (cardStatus.base.cardState == CardState::ABSENT) {
@@ -254,6 +273,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp_v1_2->rspInfo.type);
EXPECT_EQ(serial, radioRsp_v1_2->rspInfo.serial);
+ // startNetworkScan_1_2 is deprecated in radio::V1_4 with startNetworkScan_1_4
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_4);
+
ALOGI("startNetworkScan_InvalidPeriodicity1, rspInfo.error = %s\n",
toString(radioRsp_v1_2->rspInfo.error).c_str());
if (cardStatus.base.cardState == CardState::ABSENT) {
@@ -286,6 +308,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp_v1_2->rspInfo.type);
EXPECT_EQ(serial, radioRsp_v1_2->rspInfo.serial);
+ // startNetworkScan_1_2 is deprecated in radio::V1_4 with startNetworkScan_1_4
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_4);
+
ALOGI("startNetworkScan_InvalidPeriodicity2, rspInfo.error = %s\n",
toString(radioRsp_v1_2->rspInfo.error).c_str());
if (cardStatus.base.cardState == CardState::ABSENT) {
@@ -322,6 +347,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp_v1_2->rspInfo.type);
EXPECT_EQ(serial, radioRsp_v1_2->rspInfo.serial);
+ // startNetworkScan_1_2 is deprecated in radio::V1_4 with startNetworkScan_1_4
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_4);
+
ALOGI("startNetworkScan_InvalidArgument, rspInfo.error = %s\n",
toString(radioRsp_v1_2->rspInfo.error).c_str());
if (cardStatus.base.cardState == CardState::ABSENT) {
@@ -359,6 +387,9 @@
EXPECT_EQ(RadioResponseType::SOLICITED, radioRsp_v1_2->rspInfo.type);
EXPECT_EQ(serial, radioRsp_v1_2->rspInfo.serial);
+ // startNetworkScan_1_2 is deprecated in radio::V1_4 with startNetworkScan_1_4
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(1_4);
+
ALOGI("startNetworkScan_InvalidArgument, rspInfo.error = %s\n",
toString(radioRsp_v1_2->rspInfo.error).c_str());
if (cardStatus.base.cardState == CardState::ABSENT) {
diff --git a/radio/1.2/vts/functional/radio_hidl_hal_utils_v1_2.h b/radio/1.2/vts/functional/radio_hidl_hal_utils_v1_2.h
index 479340c..81286d2 100644
--- a/radio/1.2/vts/functional/radio_hidl_hal_utils_v1_2.h
+++ b/radio/1.2/vts/functional/radio_hidl_hal_utils_v1_2.h
@@ -50,6 +50,8 @@
#define TIMEOUT_PERIOD 75
#define RADIO_SERVICE_NAME "slot1"
+#define SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL_VERSION_AT_LEAST(__ver__) \
+ SKIP_TEST_IF_REQUEST_NOT_SUPPORTED_WITH_HAL(__ver__, radio_v1_2, radioRsp_v1_2)
class RadioHidlTest_v1_2;
extern ::android::hardware::radio::V1_2::CardStatus cardStatus;
@@ -682,4 +684,4 @@
/* radio config service handle */
sp<IRadioConfig> radioConfig;
-};
\ No newline at end of file
+};
diff --git a/radio/1.6/types.hal b/radio/1.6/types.hal
index 6400c63..95eba69 100644
--- a/radio/1.6/types.hal
+++ b/radio/1.6/types.hal
@@ -813,6 +813,11 @@
* see: 3GPP TS 24.501 Section 9.11.2.8.
*/
int32_t mappedHplmnSD;
+
+ /**
+ * Field to indicate the current status of the slice.
+ */
+ SliceStatus status;
};
/**
@@ -986,9 +991,9 @@
*/
vec<UrspRule> urspRules;
/**
- * Struct containing all NSSAIs (list of slice info).
+ * List of all slices.
*/
- Nssais nssais;
+ vec<SliceInfo> sliceInfo;
};
/**
@@ -1011,7 +1016,6 @@
vec<RouteSelectionDescriptor> routeSelectionDescriptor;
};
-
/**
* This struct represents a single route selection descriptor as defined in
* 3GPP TS 24.526.
@@ -1067,47 +1071,13 @@
SscMode value;
};
-/**
- * This struct contains all NSSAIs (lists of slices).
- */
-struct Nssais {
- /**
- * These are all the slices configured by the network. This includes allowed
- * and rejected slices, as well as slices that are neither allowed nor rejected
- * yet. Empty vector indicates that no slices are configured, and in that case
- * allowed and rejected vectors must be empty as well.
- */
- vec<SliceInfo> configured;
- /**
- * These are all the slices that the UE is allowed to use. All these slices
- * must be configured as well. Empty vector indicates that no slices are
- * allowed yet.
- */
- vec<SliceInfo> allowed;
- /**
- * These are all the slices that the UE is not allowed to use. All these slices
- * must be configured as well. Empty vector indicates that no slices are
- * rejected yet.
- */
- vec<RejectedSliceInfo> rejected;
- /**
- * Default configured NSSAI
- */
- vec<SliceInfo> defaultConfigured;
-};
-
-/**
- * This struct represents a network slice rejected by the network. It contains a
- * rejectionCause corresponding to a rejected network slice.
- */
-struct RejectedSliceInfo {
- SliceInfo sliceInfo;
- SliceRejectionCause rejectionCause;
-};
-
-enum SliceRejectionCause : int32_t {
- NOT_AVAILABLE_IN_PLMN,
- NOT_AVAILABLE_IN_REG_AREA,
+enum SliceStatus : int32_t {
+ UNKNOWN,
+ CONFIGURED, // Configured but not allowed or rejected yet
+ ALLOWED, // Allowed to be used
+ REJECTED_NOT_AVAILABLE_IN_PLMN, // Rejected because not available in PLMN
+ REJECTED_NOT_AVAILABLE_IN_REG_AREA, // Rejected because not available in reg area
+ DEFAULT_CONFIGURED, // Considered valid when configured/allowed slices are not available
};
/**
diff --git a/radio/1.6/vts/functional/Android.bp b/radio/1.6/vts/functional/Android.bp
index dde718b..65b0dd0 100644
--- a/radio/1.6/vts/functional/Android.bp
+++ b/radio/1.6/vts/functional/Android.bp
@@ -36,6 +36,7 @@
],
static_libs: [
"RadioVtsTestUtilBase",
+ "RadioConfigVtsTestResponse",
"android.hardware.radio@1.6",
"android.hardware.radio@1.5",
"android.hardware.radio@1.4",
@@ -45,8 +46,13 @@
"android.hardware.radio@1.0",
"android.hardware.radio.config@1.0",
"android.hardware.radio.config@1.1",
+ "android.hardware.radio.config@1.2",
+ "android.hardware.radio.config@1.3",
],
- header_libs: ["radio.util.header@1.0"],
+ header_libs: [
+ "radio.util.header@1.0",
+ "radio.config.util.header@1.3",
+ ],
test_suites: [
"general-tests",
"vts",
diff --git a/radio/1.6/vts/functional/radio_hidl_hal_test.cpp b/radio/1.6/vts/functional/radio_hidl_hal_test.cpp
index 59f7682..6255f66 100644
--- a/radio/1.6/vts/functional/radio_hidl_hal_test.cpp
+++ b/radio/1.6/vts/functional/radio_hidl_hal_test.cpp
@@ -45,35 +45,6 @@
EXPECT_EQ(CardState::PRESENT, cardStatus.base.base.base.cardState);
}
-/*
- * Notify that the response message is received.
- */
-void RadioHidlTest_v1_6::notify(int receivedSerial) {
- std::unique_lock<std::mutex> lock(mtx_);
- if (serial == receivedSerial) {
- count_++;
- cv_.notify_one();
- }
-}
-
-/*
- * Wait till the response message is notified or till TIMEOUT_PERIOD.
- */
-std::cv_status RadioHidlTest_v1_6::wait() {
- std::unique_lock<std::mutex> lock(mtx_);
-
- std::cv_status status = std::cv_status::no_timeout;
- auto now = std::chrono::system_clock::now();
- while (count_ == 0) {
- status = cv_.wait_until(lock, now + std::chrono::seconds(TIMEOUT_PERIOD));
- if (status == std::cv_status::timeout) {
- return status;
- }
- }
- count_--;
- return status;
-}
-
void RadioHidlTest_v1_6::clearPotentialEstablishedCalls() {
// Get the current call Id to hangup the established emergency call.
serial = GetRandomSerialNumber();
@@ -108,3 +79,31 @@
radio_v1_6->getDataCallList_1_6(serial);
EXPECT_EQ(std::cv_status::no_timeout, wait());
}
+
+/**
+ * Specific features on the Radio Hal rely on Radio Hal Capabilities. The VTS
+ * tests related to those features must not run if the related capability is
+ * disabled.
+ * <p/>
+ * Typical usage within VTS:
+ * if (getRadioHalCapabilities().modemReducedFeatureSet) return;
+ */
+HalDeviceCapabilities RadioHidlTest_v1_6::getRadioHalCapabilities() {
+ sp<::android::hardware::radio::config::V1_3::IRadioConfig> radioConfig_v1_3 =
+ ::android::hardware::radio::config::V1_3::IRadioConfig::getService();
+ if (radioConfig_v1_3.get() == nullptr) {
+ // If v1_3 isn't present, the values are initialized to false
+ HalDeviceCapabilities radioHalCapabilities;
+ memset(&radioHalCapabilities, 0, sizeof(radioHalCapabilities));
+ return radioHalCapabilities;
+ } else {
+ // Get radioHalDeviceCapabilities from the radio config
+ sp<RadioConfigResponse> radioConfigRsp = new (std::nothrow) RadioConfigResponse(*this);
+ radioConfig_v1_3->setResponseFunctions(radioConfigRsp, nullptr);
+ serial = GetRandomSerialNumber();
+
+ radioConfig_v1_3->getHalDeviceCapabilities(serial);
+ EXPECT_EQ(std::cv_status::no_timeout, wait());
+ return radioConfigRsp->halDeviceCapabilities;
+ }
+}
diff --git a/radio/1.6/vts/functional/radio_hidl_hal_utils_v1_6.h b/radio/1.6/vts/functional/radio_hidl_hal_utils_v1_6.h
index f610f2a..23378b5 100644
--- a/radio/1.6/vts/functional/radio_hidl_hal_utils_v1_6.h
+++ b/radio/1.6/vts/functional/radio_hidl_hal_utils_v1_6.h
@@ -18,16 +18,12 @@
#include <android-base/logging.h>
-#include <gtest/gtest.h>
-#include <hidl/GtestPrinter.h>
-#include <hidl/ServiceManagement.h>
-#include <utils/Log.h>
+#include "radio_config_hidl_hal_utils.h"
+
#include <chrono>
#include <condition_variable>
#include <mutex>
-#include <android/hardware/radio/config/1.1/IRadioConfig.h>
-
#include <android/hardware/radio/1.6/IRadio.h>
#include <android/hardware/radio/1.6/IRadioIndication.h>
#include <android/hardware/radio/1.6/IRadioResponse.h>
@@ -42,14 +38,15 @@
using namespace ::android::hardware::radio::V1_2;
using namespace ::android::hardware::radio::V1_1;
using namespace ::android::hardware::radio::V1_0;
+using namespace ::android::hardware::radio::config::V1_3;
using ::android::sp;
using ::android::hardware::hidl_string;
using ::android::hardware::hidl_vec;
using ::android::hardware::Return;
using ::android::hardware::Void;
+using ::android::hardware::radio::config::V1_3::HalDeviceCapabilities;
-#define TIMEOUT_PERIOD 75
#define MODEM_EMERGENCY_CALL_ESTABLISH_TIME 3
#define MODEM_EMERGENCY_CALL_DISCONNECT_TIME 3
@@ -61,7 +58,7 @@
/* Callback class for radio response v1_6 */
class RadioResponse_v1_6 : public ::android::hardware::radio::V1_6::IRadioResponse {
protected:
- RadioHidlTest_v1_6& parent_v1_6;
+ RadioResponseWaiter& parent_v1_6;
public:
hidl_vec<RadioBandMode> radioBandModes;
@@ -105,7 +102,7 @@
::android::hardware::radio::V1_5::CellIdentity barringCellIdentity;
::android::hardware::hidl_vec<::android::hardware::radio::V1_5::BarringInfo> barringInfos;
- RadioResponse_v1_6(RadioHidlTest_v1_6& parent_v1_6);
+ RadioResponse_v1_6(RadioResponseWaiter& parent_v1_6);
virtual ~RadioResponse_v1_6() = default;
Return<void> getIccCardStatusResponse(
@@ -1079,15 +1076,9 @@
};
// The main test class for Radio HIDL.
-class RadioHidlTest_v1_6 : public ::testing::TestWithParam<std::string> {
+class RadioHidlTest_v1_6 : public ::testing::TestWithParam<std::string>,
+ public RadioResponseWaiter {
protected:
- std::mutex mtx_;
- std::condition_variable cv_;
- int count_;
-
- /* Serial number for radio request */
- int serial;
-
/* Clear Potential Established Calls */
void clearPotentialEstablishedCalls();
@@ -1100,11 +1091,7 @@
public:
virtual void SetUp() override;
- /* Used as a mechanism to inform the test about data/event callback */
- void notify(int receivedSerial);
-
- /* Test code calls this function to wait for response */
- std::cv_status wait();
+ HalDeviceCapabilities getRadioHalCapabilities();
/* radio service handle */
sp<::android::hardware::radio::V1_6::IRadio> radio_v1_6;
diff --git a/radio/1.6/vts/functional/radio_response.cpp b/radio/1.6/vts/functional/radio_response.cpp
index d9da40a..8034fd2 100644
--- a/radio/1.6/vts/functional/radio_response.cpp
+++ b/radio/1.6/vts/functional/radio_response.cpp
@@ -18,7 +18,7 @@
::android::hardware::radio::V1_5::CardStatus cardStatus;
-RadioResponse_v1_6::RadioResponse_v1_6(RadioHidlTest_v1_6& parent) : parent_v1_6(parent) {}
+RadioResponse_v1_6::RadioResponse_v1_6(RadioResponseWaiter& parent) : parent_v1_6(parent) {}
/* 1.0 Apis */
Return<void> RadioResponse_v1_6::getIccCardStatusResponse(
diff --git a/radio/config/1.3/vts/functional/Android.bp b/radio/config/1.3/vts/functional/Android.bp
index aa3522d..20c480f 100644
--- a/radio/config/1.3/vts/functional/Android.bp
+++ b/radio/config/1.3/vts/functional/Android.bp
@@ -46,3 +46,26 @@
"vts",
],
}
+
+cc_library_static {
+ name: "RadioConfigVtsTestResponse",
+ defaults: ["VtsHalTargetTestDefaults"],
+ srcs : [
+ "radio_config_response.cpp",
+ "radio_config_hidl_hal_test.cpp",
+ ],
+ header_libs: ["radio.util.header@1.0"],
+ static_libs: ["RadioVtsTestUtilBase"],
+ shared_libs: [
+ "android.hardware.radio@1.0",
+ "android.hardware.radio.config@1.0",
+ "android.hardware.radio.config@1.1",
+ "android.hardware.radio.config@1.2",
+ "android.hardware.radio.config@1.3",
+ ],
+}
+
+cc_library_headers {
+ name: "radio.config.util.header@1.3",
+ export_include_dirs: ["."],
+}
diff --git a/radio/config/1.3/vts/functional/radio_config_hidl_hal_test.cpp b/radio/config/1.3/vts/functional/radio_config_hidl_hal_test.cpp
index de8365a..da61464 100644
--- a/radio/config/1.3/vts/functional/radio_config_hidl_hal_test.cpp
+++ b/radio/config/1.3/vts/functional/radio_config_hidl_hal_test.cpp
@@ -31,32 +31,3 @@
radioConfig->setResponseFunctions(radioConfigRsp, nullptr);
}
-
-/*
- * Notify that the response message is received.
- */
-void RadioConfigHidlTest::notify(int receivedSerial) {
- std::unique_lock<std::mutex> lock(mtx_);
- if (serial == receivedSerial) {
- count_++;
- cv_.notify_one();
- }
-}
-
-/*
- * Wait till the response message is notified or till TIMEOUT_PERIOD.
- */
-std::cv_status RadioConfigHidlTest::wait() {
- std::unique_lock<std::mutex> lock(mtx_);
-
- std::cv_status status = std::cv_status::no_timeout;
- auto now = std::chrono::system_clock::now();
- while (count_ == 0) {
- status = cv_.wait_until(lock, now + std::chrono::seconds(TIMEOUT_PERIOD));
- if (status == std::cv_status::timeout) {
- return status;
- }
- }
- count_--;
- return status;
-}
diff --git a/radio/config/1.3/vts/functional/radio_config_hidl_hal_utils.h b/radio/config/1.3/vts/functional/radio_config_hidl_hal_utils.h
index 439eb70..895ae08 100644
--- a/radio/config/1.3/vts/functional/radio_config_hidl_hal_utils.h
+++ b/radio/config/1.3/vts/functional/radio_config_hidl_hal_utils.h
@@ -14,6 +14,8 @@
* limitations under the License.
*/
+#pragma once
+
#include <android-base/logging.h>
#include <chrono>
@@ -49,7 +51,6 @@
using ::android::hardware::radio::config::V1_3::IRadioConfig;
using ::android::hardware::radio::V1_0::RadioResponseInfo;
-#define TIMEOUT_PERIOD 75
#define RADIO_SERVICE_NAME "slot1"
class RadioConfigHidlTest;
@@ -57,13 +58,14 @@
/* Callback class for radio config response */
class RadioConfigResponse : public IRadioConfigResponse {
protected:
- RadioConfigHidlTest& parent;
+ RadioResponseWaiter& parent;
public:
RadioResponseInfo rspInfo;
PhoneCapability phoneCap;
+ HalDeviceCapabilities halDeviceCapabilities;
- RadioConfigResponse(RadioConfigHidlTest& parent);
+ RadioConfigResponse(RadioResponseWaiter& parent);
virtual ~RadioConfigResponse() = default;
Return<void> getSimSlotsStatusResponse(
@@ -107,26 +109,13 @@
};
// The main test class for Radio config HIDL.
-class RadioConfigHidlTest : public ::testing::TestWithParam<std::string> {
- protected:
- std::mutex mtx_;
- std::condition_variable cv_;
- int count_;
-
+class RadioConfigHidlTest : public ::testing::TestWithParam<std::string>,
+ public RadioResponseWaiter {
public:
virtual void SetUp() override;
- /* Used as a mechanism to inform the test about data/event callback */
- void notify(int receivedSerial);
-
- /* Test code calls this function to wait for response */
- std::cv_status wait();
-
void updateSimCardStatus();
- /* Serial number for radio request */
- int serial;
-
/* radio config service handle */
sp<IRadioConfig> radioConfig;
diff --git a/radio/config/1.3/vts/functional/radio_config_response.cpp b/radio/config/1.3/vts/functional/radio_config_response.cpp
index 2a8b28b..11e3cce 100644
--- a/radio/config/1.3/vts/functional/radio_config_response.cpp
+++ b/radio/config/1.3/vts/functional/radio_config_response.cpp
@@ -18,7 +18,7 @@
// SimSlotStatus slotStatus;
-RadioConfigResponse::RadioConfigResponse(RadioConfigHidlTest& parent) : parent(parent) {}
+RadioConfigResponse::RadioConfigResponse(RadioResponseWaiter& parent) : parent(parent) {}
Return<void> RadioConfigResponse::getSimSlotsStatusResponse(
const ::android::hardware::radio::V1_0::RadioResponseInfo& /* info */,
@@ -65,6 +65,7 @@
Return<void> RadioConfigResponse::getHalDeviceCapabilitiesResponse(
const ::android::hardware::radio::V1_6::RadioResponseInfo& /* info */,
- const ::android::hardware::radio::config::V1_3::HalDeviceCapabilities& /* capabilities */) {
+ const ::android::hardware::radio::config::V1_3::HalDeviceCapabilities& capabilities) {
+ halDeviceCapabilities = capabilities;
return Void();
-}
\ No newline at end of file
+}
diff --git a/tv/tuner/1.0/default/Android.bp b/tv/tuner/1.0/default/Android.bp
index c85fbdf..ae15b6c 100644
--- a/tv/tuner/1.0/default/Android.bp
+++ b/tv/tuner/1.0/default/Android.bp
@@ -33,7 +33,7 @@
"libfmq",
"libhidlbase",
"libhidlmemory",
- "libion",
+ "libdmabufheap",
"liblog",
"libstagefright_foundation",
"libutils",
diff --git a/tv/tuner/1.0/default/Filter.cpp b/tv/tuner/1.0/default/Filter.cpp
index ce748e5..7b50f8c 100644
--- a/tv/tuner/1.0/default/Filter.cpp
+++ b/tv/tuner/1.0/default/Filter.cpp
@@ -16,9 +16,11 @@
#define LOG_TAG "android.hardware.tv.tuner@1.0-Filter"
-#include "Filter.h"
+#include <BufferAllocator/BufferAllocator.h>
#include <utils/Log.h>
+#include "Filter.h"
+
namespace android {
namespace hardware {
namespace tv {
@@ -622,15 +624,15 @@
}
int Filter::createAvIonFd(int size) {
- // Create an ion fd and allocate an av fd mapped to a buffer to it.
- int ion_fd = ion_open();
- if (ion_fd == -1) {
- ALOGE("[Filter] Failed to open ion fd %d", errno);
+ // Allocate a DMA-BUF from the "system-uncached" heap and return its fd as the av fd.
+ auto buffer_allocator = std::make_unique<BufferAllocator>();
+ if (!buffer_allocator) {
+ ALOGE("[Filter] Unable to create BufferAllocator object");
return -1;
}
int av_fd = -1;
- ion_alloc_fd(dup(ion_fd), size, 0 /*align*/, ION_HEAP_SYSTEM_MASK, 0 /*flags*/, &av_fd);
- if (av_fd == -1) {
+ av_fd = buffer_allocator->Alloc("system-uncached", size);
+ if (av_fd < 0) {
ALOGE("[Filter] Failed to create av fd %d", errno);
return -1;
}
diff --git a/tv/tuner/1.1/default/Android.bp b/tv/tuner/1.1/default/Android.bp
index 86025cf..a612802 100644
--- a/tv/tuner/1.1/default/Android.bp
+++ b/tv/tuner/1.1/default/Android.bp
@@ -31,6 +31,7 @@
"android.hardware.tv.tuner@1.1",
"android.hidl.memory@1.0",
"libcutils",
+ "libdmabufheap",
"libfmq",
"libhidlbase",
"libhidlmemory",
diff --git a/tv/tuner/1.1/default/Filter.cpp b/tv/tuner/1.1/default/Filter.cpp
index aec1fd0..7d609ea 100644
--- a/tv/tuner/1.1/default/Filter.cpp
+++ b/tv/tuner/1.1/default/Filter.cpp
@@ -16,9 +16,11 @@
#define LOG_TAG "android.hardware.tv.tuner@1.1-Filter"
-#include "Filter.h"
+#include <BufferAllocator/BufferAllocator.h>
#include <utils/Log.h>
+#include "Filter.h"
+
namespace android {
namespace hardware {
namespace tv {
@@ -259,11 +261,14 @@
int av_fd = createAvIonFd(BUFFER_SIZE_16M);
if (av_fd == -1) {
_hidl_cb(Result::UNKNOWN_ERROR, NULL, 0);
+ return Void();
}
native_handle_t* nativeHandle = createNativeHandle(av_fd);
if (nativeHandle == NULL) {
+ ::close(av_fd);
_hidl_cb(Result::UNKNOWN_ERROR, NULL, 0);
+ return Void();
}
mSharedAvMemHandle.setTo(nativeHandle, /*shouldOwn=*/true);
::close(av_fd);
@@ -826,15 +831,15 @@
}
int Filter::createAvIonFd(int size) {
- // Create an ion fd and allocate an av fd mapped to a buffer to it.
- int ion_fd = ion_open();
- if (ion_fd == -1) {
- ALOGE("[Filter] Failed to open ion fd %d", errno);
+ // Allocate a DMA-BUF from the "system-uncached" heap and return its fd as the av fd.
+ auto buffer_allocator = std::make_unique<BufferAllocator>();
+ if (!buffer_allocator) {
+ ALOGE("[Filter] Unable to create BufferAllocator object");
return -1;
}
int av_fd = -1;
- ion_alloc_fd(dup(ion_fd), size, 0 /*align*/, ION_HEAP_SYSTEM_MASK, 0 /*flags*/, &av_fd);
- if (av_fd == -1) {
+ av_fd = buffer_allocator->Alloc("system-uncached", size);
+ if (av_fd < 0) {
ALOGE("[Filter] Failed to create av fd %d", errno);
return -1;
}
diff --git a/wifi/1.5/default/hidl_struct_util.cpp b/wifi/1.5/default/hidl_struct_util.cpp
index cd0edbe..baa898e 100644
--- a/wifi/1.5/default/hidl_struct_util.cpp
+++ b/wifi/1.5/default/hidl_struct_util.cpp
@@ -1077,6 +1077,17 @@
legacy_stats.iface.ac[legacy_hal::WIFI_AC_VO].contention_num_samples;
hidl_stats->iface.timeSliceDutyCycleInPercent =
legacy_stats.iface.info.time_slicing_duty_cycle_percent;
+ // peer info legacy_stats conversion.
+ std::vector<StaPeerInfo> hidl_peers_info_stats;
+ for (const auto& legacy_peer_info_stats : legacy_stats.peers) {
+ StaPeerInfo hidl_peer_info_stats;
+ if (!convertLegacyPeerInfoStatsToHidl(legacy_peer_info_stats,
+ &hidl_peer_info_stats)) {
+ return false;
+ }
+ hidl_peers_info_stats.push_back(hidl_peer_info_stats);
+ }
+ hidl_stats->iface.peers = hidl_peers_info_stats;
// radio legacy_stats conversion.
std::vector<V1_3::StaLinkLayerRadioStats> hidl_radios_stats;
for (const auto& legacy_radio_stats : legacy_stats.radios) {
@@ -1094,6 +1105,35 @@
return true;
}
+bool convertLegacyPeerInfoStatsToHidl(
+ const legacy_hal::WifiPeerInfo& legacy_peer_info_stats,
+ StaPeerInfo* hidl_peer_info_stats) {
+ if (!hidl_peer_info_stats) {
+ return false;
+ }
+ *hidl_peer_info_stats = {};
+ hidl_peer_info_stats->staCount =
+ legacy_peer_info_stats.peer_info.bssload.sta_count;
+ hidl_peer_info_stats->chanUtil =
+ legacy_peer_info_stats.peer_info.bssload.chan_util;
+
+ std::vector<StaRateStat> hidlRateStats;
+ for (const auto& legacy_rate_stats : legacy_peer_info_stats.rate_stats) {
+ StaRateStat rateStat;
+ if (!convertLegacyWifiRateInfoToHidl(legacy_rate_stats.rate,
+ &rateStat.rateInfo)) {
+ return false;
+ }
+ rateStat.txMpdu = legacy_rate_stats.tx_mpdu;
+ rateStat.rxMpdu = legacy_rate_stats.rx_mpdu;
+ rateStat.mpduLost = legacy_rate_stats.mpdu_lost;
+ rateStat.retries = legacy_rate_stats.retries;
+ hidlRateStats.push_back(rateStat);
+ }
+ hidl_peer_info_stats->rateStats = hidlRateStats;
+ return true;
+}
+
bool convertLegacyRoamingCapabilitiesToHidl(
const legacy_hal::wifi_roaming_capabilities& legacy_caps,
StaRoamingCapabilities* hidl_caps) {
diff --git a/wifi/1.5/default/hidl_struct_util.h b/wifi/1.5/default/hidl_struct_util.h
index 8b81033..352f213 100644
--- a/wifi/1.5/default/hidl_struct_util.h
+++ b/wifi/1.5/default/hidl_struct_util.h
@@ -212,6 +212,11 @@
bool convertLegacyWifiUsableChannelsToHidl(
const std::vector<legacy_hal::wifi_usable_channel>& legacy_usable_channels,
std::vector<V1_5::WifiUsableChannel>* hidl_usable_channels);
+bool convertLegacyPeerInfoStatsToHidl(
+ const legacy_hal::WifiPeerInfo& legacy_peer_info_stats,
+ StaPeerInfo* hidl_peer_info_stats);
+bool convertLegacyWifiRateInfoToHidl(const legacy_hal::wifi_rate& legacy_rate,
+ V1_4::WifiRateInfo* hidl_rate);
} // namespace hidl_struct_util
} // namespace implementation
} // namespace V1_5
diff --git a/wifi/1.5/default/service.cpp b/wifi/1.5/default/service.cpp
index 23e2b47..3de49b2 100644
--- a/wifi/1.5/default/service.cpp
+++ b/wifi/1.5/default/service.cpp
@@ -32,7 +32,6 @@
using android::hardware::LazyServiceRegistrar;
using android::hardware::wifi::V1_5::implementation::feature_flags::
WifiFeatureFlags;
-using android::hardware::wifi::V1_5::implementation::iface_util::WifiIfaceUtil;
using android::hardware::wifi::V1_5::implementation::legacy_hal::WifiLegacyHal;
using android::hardware::wifi::V1_5::implementation::legacy_hal::
WifiLegacyHalFactory;
@@ -63,7 +62,6 @@
new android::hardware::wifi::V1_5::implementation::Wifi(
iface_tool, legacy_hal_factory,
std::make_shared<WifiModeController>(),
- std::make_shared<WifiIfaceUtil>(iface_tool),
std::make_shared<WifiFeatureFlags>());
if (kLazyService) {
auto registrar = LazyServiceRegistrar::getInstance();
diff --git a/wifi/1.5/default/tests/hidl_struct_util_unit_tests.cpp b/wifi/1.5/default/tests/hidl_struct_util_unit_tests.cpp
index 6391a6a..e70d7ba 100644
--- a/wifi/1.5/default/tests/hidl_struct_util_unit_tests.cpp
+++ b/wifi/1.5/default/tests/hidl_struct_util_unit_tests.cpp
@@ -132,6 +132,8 @@
legacy_hal::LinkLayerStats legacy_stats{};
legacy_stats.radios.push_back(legacy_hal::LinkLayerRadioStats{});
legacy_stats.radios.push_back(legacy_hal::LinkLayerRadioStats{});
+ legacy_stats.peers.push_back(legacy_hal::WifiPeerInfo{});
+ legacy_stats.peers.push_back(legacy_hal::WifiPeerInfo{});
legacy_stats.iface.beacon_rx = rand();
legacy_stats.iface.rssi_mgmt = rand();
legacy_stats.iface.ac[legacy_hal::WIFI_AC_BE].rx_mpdu = rand();
@@ -175,6 +177,7 @@
rand();
legacy_stats.iface.info.time_slicing_duty_cycle_percent = rand();
+ legacy_stats.iface.num_peers = 1;
for (auto& radio : legacy_stats.radios) {
radio.stats.on_time = rand();
@@ -204,6 +207,31 @@
radio.channel_stats.push_back(channel_stat2);
}
+ for (auto& peer : legacy_stats.peers) {
+ peer.peer_info.bssload.sta_count = rand();
+ peer.peer_info.bssload.chan_util = rand();
+ wifi_rate_stat rate_stat1 = {
+ .rate = {3, 1, 2, 5, 0, 0},
+ .tx_mpdu = 0,
+ .rx_mpdu = 1,
+ .mpdu_lost = 2,
+ .retries = 3,
+ .retries_short = 4,
+ .retries_long = 5,
+ };
+ wifi_rate_stat rate_stat2 = {
+ .rate = {2, 2, 1, 6, 0, 1},
+ .tx_mpdu = 6,
+ .rx_mpdu = 7,
+ .mpdu_lost = 8,
+ .retries = 9,
+ .retries_short = 10,
+ .retries_long = 11,
+ };
+ peer.rate_stats.push_back(rate_stat1);
+ peer.rate_stats.push_back(rate_stat2);
+ }
+
V1_5::StaLinkLayerStats converted{};
hidl_struct_util::convertLegacyLinkLayerStatsToHidl(legacy_stats,
&converted);
@@ -330,6 +358,37 @@
converted.radios[i].channelStats[k].onTimeInMs);
}
}
+
+ EXPECT_EQ(legacy_stats.peers.size(), converted.iface.peers.size());
+ for (size_t i = 0; i < legacy_stats.peers.size(); i++) {
+ EXPECT_EQ(legacy_stats.peers[i].peer_info.bssload.sta_count,
+ converted.iface.peers[i].staCount);
+ EXPECT_EQ(legacy_stats.peers[i].peer_info.bssload.chan_util,
+ converted.iface.peers[i].chanUtil);
+ for (size_t j = 0; j < legacy_stats.peers[i].rate_stats.size(); j++) {
+ EXPECT_EQ(legacy_stats.peers[i].rate_stats[j].rate.preamble,
+ (uint32_t)converted.iface.peers[i]
+ .rateStats[j]
+ .rateInfo.preamble);
+ EXPECT_EQ(
+ legacy_stats.peers[i].rate_stats[j].rate.nss,
+ (uint32_t)converted.iface.peers[i].rateStats[j].rateInfo.nss);
+ EXPECT_EQ(
+ legacy_stats.peers[i].rate_stats[j].rate.bw,
+ (uint32_t)converted.iface.peers[i].rateStats[j].rateInfo.bw);
+ EXPECT_EQ(
+ legacy_stats.peers[i].rate_stats[j].rate.rateMcsIdx,
+ converted.iface.peers[i].rateStats[j].rateInfo.rateMcsIdx);
+ EXPECT_EQ(legacy_stats.peers[i].rate_stats[j].tx_mpdu,
+ converted.iface.peers[i].rateStats[j].txMpdu);
+ EXPECT_EQ(legacy_stats.peers[i].rate_stats[j].rx_mpdu,
+ converted.iface.peers[i].rateStats[j].rxMpdu);
+ EXPECT_EQ(legacy_stats.peers[i].rate_stats[j].mpdu_lost,
+ converted.iface.peers[i].rateStats[j].mpduLost);
+ EXPECT_EQ(legacy_stats.peers[i].rate_stats[j].retries,
+ converted.iface.peers[i].rateStats[j].retries);
+ }
+ }
}
TEST_F(HidlStructUtilTest, CanConvertLegacyFeaturesToHidl) {
diff --git a/wifi/1.5/default/tests/mock_wifi_iface_util.cpp b/wifi/1.5/default/tests/mock_wifi_iface_util.cpp
index fe6e9e2..b101c4a 100644
--- a/wifi/1.5/default/tests/mock_wifi_iface_util.cpp
+++ b/wifi/1.5/default/tests/mock_wifi_iface_util.cpp
@@ -29,8 +29,9 @@
namespace iface_util {
MockWifiIfaceUtil::MockWifiIfaceUtil(
- const std::weak_ptr<wifi_system::InterfaceTool> iface_tool)
- : WifiIfaceUtil(iface_tool) {}
+ const std::weak_ptr<wifi_system::InterfaceTool> iface_tool,
+ const std::weak_ptr<legacy_hal::WifiLegacyHal> legacy_hal)
+ : WifiIfaceUtil(iface_tool, legacy_hal) {}
} // namespace iface_util
} // namespace implementation
} // namespace V1_5
diff --git a/wifi/1.5/default/tests/mock_wifi_iface_util.h b/wifi/1.5/default/tests/mock_wifi_iface_util.h
index a719fec..6d5f59c 100644
--- a/wifi/1.5/default/tests/mock_wifi_iface_util.h
+++ b/wifi/1.5/default/tests/mock_wifi_iface_util.h
@@ -31,7 +31,8 @@
class MockWifiIfaceUtil : public WifiIfaceUtil {
public:
MockWifiIfaceUtil(
- const std::weak_ptr<wifi_system::InterfaceTool> iface_tool);
+ const std::weak_ptr<wifi_system::InterfaceTool> iface_tool,
+ const std::weak_ptr<legacy_hal::WifiLegacyHal> legacy_hal);
MOCK_METHOD1(getFactoryMacAddress,
std::array<uint8_t, 6>(const std::string&));
MOCK_METHOD2(setMacAddress,
diff --git a/wifi/1.5/default/tests/mock_wifi_legacy_hal.h b/wifi/1.5/default/tests/mock_wifi_legacy_hal.h
index 9ab2fd5..826b35f 100644
--- a/wifi/1.5/default/tests/mock_wifi_legacy_hal.h
+++ b/wifi/1.5/default/tests/mock_wifi_legacy_hal.h
@@ -62,6 +62,7 @@
wifi_error(const std::string& ifname,
wifi_interface_type iftype));
MOCK_METHOD1(deleteVirtualInterface, wifi_error(const std::string& ifname));
+ MOCK_METHOD0(waitForDriverReady, wifi_error());
};
} // namespace legacy_hal
} // namespace implementation
diff --git a/wifi/1.5/default/tests/wifi_chip_unit_tests.cpp b/wifi/1.5/default/tests/wifi_chip_unit_tests.cpp
index d99bfbd..0ad6f3e 100644
--- a/wifi/1.5/default/tests/wifi_chip_unit_tests.cpp
+++ b/wifi/1.5/default/tests/wifi_chip_unit_tests.cpp
@@ -276,7 +276,7 @@
std::shared_ptr<NiceMock<mode_controller::MockWifiModeController>>
mode_controller_{new NiceMock<mode_controller::MockWifiModeController>};
std::shared_ptr<NiceMock<iface_util::MockWifiIfaceUtil>> iface_util_{
- new NiceMock<iface_util::MockWifiIfaceUtil>(iface_tool_)};
+ new NiceMock<iface_util::MockWifiIfaceUtil>(iface_tool_, legacy_hal_)};
std::shared_ptr<NiceMock<feature_flags::MockWifiFeatureFlags>>
feature_flags_{new NiceMock<feature_flags::MockWifiFeatureFlags>};
diff --git a/wifi/1.5/default/tests/wifi_iface_util_unit_tests.cpp b/wifi/1.5/default/tests/wifi_iface_util_unit_tests.cpp
index d70e42f..8b67bb8 100644
--- a/wifi/1.5/default/tests/wifi_iface_util_unit_tests.cpp
+++ b/wifi/1.5/default/tests/wifi_iface_util_unit_tests.cpp
@@ -22,6 +22,7 @@
#include "wifi_iface_util.h"
#include "mock_interface_tool.h"
+#include "mock_wifi_legacy_hal.h"
using testing::NiceMock;
using testing::Test;
@@ -48,7 +49,11 @@
protected:
std::shared_ptr<NiceMock<wifi_system::MockInterfaceTool>> iface_tool_{
new NiceMock<wifi_system::MockInterfaceTool>};
- WifiIfaceUtil* iface_util_ = new WifiIfaceUtil(iface_tool_);
+ legacy_hal::wifi_hal_fn fake_func_table_;
+ std::shared_ptr<NiceMock<legacy_hal::MockWifiLegacyHal>> legacy_hal_{
+ new NiceMock<legacy_hal::MockWifiLegacyHal>(iface_tool_,
+ fake_func_table_, true)};
+ WifiIfaceUtil* iface_util_ = new WifiIfaceUtil(iface_tool_, legacy_hal_);
};
TEST_F(WifiIfaceUtilTest, GetOrCreateRandomMacAddress) {
diff --git a/wifi/1.5/default/tests/wifi_nan_iface_unit_tests.cpp b/wifi/1.5/default/tests/wifi_nan_iface_unit_tests.cpp
index 52f0c2b..32da55e 100644
--- a/wifi/1.5/default/tests/wifi_nan_iface_unit_tests.cpp
+++ b/wifi/1.5/default/tests/wifi_nan_iface_unit_tests.cpp
@@ -122,7 +122,7 @@
new NiceMock<legacy_hal::MockWifiLegacyHal>(iface_tool_,
fake_func_table_, true)};
std::shared_ptr<NiceMock<iface_util::MockWifiIfaceUtil>> iface_util_{
- new NiceMock<iface_util::MockWifiIfaceUtil>(iface_tool_)};
+ new NiceMock<iface_util::MockWifiIfaceUtil>(iface_tool_, legacy_hal_)};
};
TEST_F(WifiNanIfaceTest, IfacEventHandlers_OnStateToggleOffOn) {
diff --git a/wifi/1.5/default/wifi.cpp b/wifi/1.5/default/wifi.cpp
index 17db51d..da98db8 100644
--- a/wifi/1.5/default/wifi.cpp
+++ b/wifi/1.5/default/wifi.cpp
@@ -37,12 +37,10 @@
const std::shared_ptr<wifi_system::InterfaceTool> iface_tool,
const std::shared_ptr<legacy_hal::WifiLegacyHalFactory> legacy_hal_factory,
const std::shared_ptr<mode_controller::WifiModeController> mode_controller,
- const std::shared_ptr<iface_util::WifiIfaceUtil> iface_util,
const std::shared_ptr<feature_flags::WifiFeatureFlags> feature_flags)
: iface_tool_(iface_tool),
legacy_hal_factory_(legacy_hal_factory),
mode_controller_(mode_controller),
- iface_util_(iface_util),
feature_flags_(feature_flags),
run_state_(RunState::STOPPED) {}
@@ -130,7 +128,8 @@
for (auto& hal : legacy_hals_) {
chips_.push_back(new WifiChip(
chipId, chipId == kPrimaryChipId, hal, mode_controller_,
- iface_util_, feature_flags_, on_subsystem_restart_callback));
+ std::make_shared<iface_util::WifiIfaceUtil>(iface_tool_, hal),
+ feature_flags_, on_subsystem_restart_callback));
chipId++;
}
run_state_ = RunState::STARTED;
diff --git a/wifi/1.5/default/wifi.h b/wifi/1.5/default/wifi.h
index 9f5a1b0..825c0bc 100644
--- a/wifi/1.5/default/wifi.h
+++ b/wifi/1.5/default/wifi.h
@@ -46,7 +46,6 @@
legacy_hal_factory,
const std::shared_ptr<mode_controller::WifiModeController>
mode_controller,
- const std::shared_ptr<iface_util::WifiIfaceUtil> iface_util,
const std::shared_ptr<feature_flags::WifiFeatureFlags> feature_flags);
bool isValid();
@@ -85,7 +84,6 @@
std::shared_ptr<legacy_hal::WifiLegacyHalFactory> legacy_hal_factory_;
std::shared_ptr<mode_controller::WifiModeController> mode_controller_;
std::vector<std::shared_ptr<legacy_hal::WifiLegacyHal>> legacy_hals_;
- std::shared_ptr<iface_util::WifiIfaceUtil> iface_util_;
std::shared_ptr<feature_flags::WifiFeatureFlags> feature_flags_;
RunState run_state_;
std::vector<sp<WifiChip>> chips_;
diff --git a/wifi/1.5/default/wifi_chip.cpp b/wifi/1.5/default/wifi_chip.cpp
index 0450a7b..0499f45 100644
--- a/wifi/1.5/default/wifi_chip.cpp
+++ b/wifi/1.5/default/wifi_chip.cpp
@@ -353,7 +353,7 @@
ChipId chip_id, bool is_primary,
const std::weak_ptr<legacy_hal::WifiLegacyHal> legacy_hal,
const std::weak_ptr<mode_controller::WifiModeController> mode_controller,
- const std::weak_ptr<iface_util::WifiIfaceUtil> iface_util,
+ const std::shared_ptr<iface_util::WifiIfaceUtil> iface_util,
const std::weak_ptr<feature_flags::WifiFeatureFlags> feature_flags,
const std::function<void(const std::string&)>& handler)
: chip_id_(chip_id),
@@ -986,14 +986,14 @@
}
}
br_ifaces_ap_instances_[br_ifname] = ap_instances;
- if (!iface_util_.lock()->createBridge(br_ifname)) {
+ if (!iface_util_->createBridge(br_ifname)) {
LOG(ERROR) << "Failed createBridge - br_name=" << br_ifname.c_str();
invalidateAndClearBridgedAp(br_ifname);
return {createWifiStatus(WifiStatusCode::ERROR_NOT_AVAILABLE), {}};
}
for (auto const& instance : ap_instances) {
// Bind ap instance interface to AP bridge
- if (!iface_util_.lock()->addIfaceToBridge(br_ifname, instance)) {
+ if (!iface_util_->addIfaceToBridge(br_ifname, instance)) {
LOG(ERROR) << "Failed add if to Bridge - if_name="
<< instance.c_str();
invalidateAndClearBridgedAp(br_ifname);
@@ -1054,8 +1054,7 @@
if (it.first == ifname) {
for (auto const& iface : it.second) {
if (iface == ifInstanceName) {
- if (!iface_util_.lock()->removeIfaceFromBridge(it.first,
- iface)) {
+ if (!iface_util_->removeIfaceFromBridge(it.first, iface)) {
LOG(ERROR)
<< "Failed to remove interface: " << ifInstanceName
<< " from " << ifname;
@@ -1086,7 +1085,7 @@
}
bool is_dedicated_iface = true;
std::string ifname = getPredefinedNanIfaceName();
- if (ifname.empty() || !iface_util_.lock()->ifNameToIndex(ifname)) {
+ if (ifname.empty() || !iface_util_->ifNameToIndex(ifname)) {
// Use the first shared STA iface (wlan0) if a dedicated aware iface is
// not defined.
ifname = getFirstActiveWlanIfaceName();
@@ -1968,10 +1967,10 @@
void WifiChip::invalidateAndClearBridgedApAll() {
for (auto const& it : br_ifaces_ap_instances_) {
for (auto const& iface : it.second) {
- iface_util_.lock()->removeIfaceFromBridge(it.first, iface);
+ iface_util_->removeIfaceFromBridge(it.first, iface);
legacy_hal_.lock()->deleteVirtualInterface(iface);
}
- iface_util_.lock()->deleteBridge(it.first);
+ iface_util_->deleteBridge(it.first);
}
br_ifaces_ap_instances_.clear();
}
@@ -1982,10 +1981,10 @@
for (auto const& it : br_ifaces_ap_instances_) {
if (it.first == br_name) {
for (auto const& iface : it.second) {
- iface_util_.lock()->removeIfaceFromBridge(br_name, iface);
+ iface_util_->removeIfaceFromBridge(br_name, iface);
legacy_hal_.lock()->deleteVirtualInterface(iface);
}
- iface_util_.lock()->deleteBridge(br_name);
+ iface_util_->deleteBridge(br_name);
br_ifaces_ap_instances_.erase(br_name);
break;
}
diff --git a/wifi/1.5/default/wifi_chip.h b/wifi/1.5/default/wifi_chip.h
index b4ed30e..92d639f 100644
--- a/wifi/1.5/default/wifi_chip.h
+++ b/wifi/1.5/default/wifi_chip.h
@@ -54,7 +54,7 @@
const std::weak_ptr<legacy_hal::WifiLegacyHal> legacy_hal,
const std::weak_ptr<mode_controller::WifiModeController>
mode_controller,
- const std::weak_ptr<iface_util::WifiIfaceUtil> iface_util,
+ const std::shared_ptr<iface_util::WifiIfaceUtil> iface_util,
const std::weak_ptr<feature_flags::WifiFeatureFlags> feature_flags,
const std::function<void(const std::string&)>&
subsystemCallbackHandler);
@@ -307,7 +307,7 @@
ChipId chip_id_;
std::weak_ptr<legacy_hal::WifiLegacyHal> legacy_hal_;
std::weak_ptr<mode_controller::WifiModeController> mode_controller_;
- std::weak_ptr<iface_util::WifiIfaceUtil> iface_util_;
+ std::shared_ptr<iface_util::WifiIfaceUtil> iface_util_;
std::vector<sp<WifiApIface>> ap_ifaces_;
std::vector<sp<WifiNanIface>> nan_ifaces_;
std::vector<sp<WifiP2pIface>> p2p_ifaces_;
diff --git a/wifi/1.5/default/wifi_iface_util.cpp b/wifi/1.5/default/wifi_iface_util.cpp
index 2a0aef8..d1434e3 100644
--- a/wifi/1.5/default/wifi_iface_util.cpp
+++ b/wifi/1.5/default/wifi_iface_util.cpp
@@ -41,8 +41,10 @@
namespace iface_util {
WifiIfaceUtil::WifiIfaceUtil(
- const std::weak_ptr<wifi_system::InterfaceTool> iface_tool)
+ const std::weak_ptr<wifi_system::InterfaceTool> iface_tool,
+ const std::weak_ptr<legacy_hal::WifiLegacyHal> legacy_hal)
: iface_tool_(iface_tool),
+ legacy_hal_(legacy_hal),
random_mac_address_(nullptr),
event_handlers_map_() {}
@@ -59,14 +61,20 @@
return false;
}
#endif
- if (!iface_tool_.lock()->SetMacAddress(iface_name.c_str(), mac)) {
- LOG(ERROR) << "SetMacAddress failed.";
- return false;
- }
+ bool success = iface_tool_.lock()->SetMacAddress(iface_name.c_str(), mac);
#ifndef WIFI_AVOID_IFACE_RESET_MAC_CHANGE
if (!iface_tool_.lock()->SetUpState(iface_name.c_str(), true)) {
- LOG(ERROR) << "SetUpState(true) failed.";
- return false;
+ LOG(ERROR) << "SetUpState(true) failed. Wait for driver ready.";
+ // Wait for driver ready and try to set iface UP again
+ if (legacy_hal_.lock()->waitForDriverReady() !=
+ legacy_hal::WIFI_SUCCESS) {
+ LOG(ERROR) << "SetUpState(true) wait for driver ready failed.";
+ return false;
+ }
+ if (!iface_tool_.lock()->SetUpState(iface_name.c_str(), true)) {
+ LOG(ERROR) << "SetUpState(true) failed after retry.";
+ return false;
+ }
}
#endif
IfaceEventHandlers event_handlers = {};
@@ -77,8 +85,12 @@
if (event_handlers.on_state_toggle_off_on != nullptr) {
event_handlers.on_state_toggle_off_on(iface_name);
}
- LOG(DEBUG) << "Successfully SetMacAddress.";
- return true;
+ if (!success) {
+ LOG(ERROR) << "SetMacAddress failed.";
+ } else {
+ LOG(DEBUG) << "SetMacAddress succeeded.";
+ }
+ return success;
}
std::array<uint8_t, 6> WifiIfaceUtil::getOrCreateRandomMacAddress() {
diff --git a/wifi/1.5/default/wifi_iface_util.h b/wifi/1.5/default/wifi_iface_util.h
index 296d182..b449077 100644
--- a/wifi/1.5/default/wifi_iface_util.h
+++ b/wifi/1.5/default/wifi_iface_util.h
@@ -21,6 +21,8 @@
#include <android/hardware/wifi/1.0/IWifi.h>
+#include "wifi_legacy_hal.h"
+
namespace android {
namespace hardware {
namespace wifi {
@@ -40,7 +42,8 @@
*/
class WifiIfaceUtil {
public:
- WifiIfaceUtil(const std::weak_ptr<wifi_system::InterfaceTool> iface_tool);
+ WifiIfaceUtil(const std::weak_ptr<wifi_system::InterfaceTool> iface_tool,
+ const std::weak_ptr<legacy_hal::WifiLegacyHal> legacy_hal);
virtual ~WifiIfaceUtil() = default;
virtual std::array<uint8_t, 6> getFactoryMacAddress(
@@ -73,6 +76,7 @@
std::array<uint8_t, 6> createRandomMacAddress();
std::weak_ptr<wifi_system::InterfaceTool> iface_tool_;
+ std::weak_ptr<legacy_hal::WifiLegacyHal> legacy_hal_;
std::unique_ptr<std::array<uint8_t, 6>> random_mac_address_;
std::map<std::string, IfaceEventHandlers> event_handlers_map_;
};
diff --git a/wifi/1.5/default/wifi_legacy_hal.cpp b/wifi/1.5/default/wifi_legacy_hal.cpp
index f5ca753..45ad84b 100644
--- a/wifi/1.5/default/wifi_legacy_hal.cpp
+++ b/wifi/1.5/default/wifi_legacy_hal.cpp
@@ -476,6 +476,10 @@
bool WifiLegacyHal::isStarted() { return is_started_; }
+wifi_error WifiLegacyHal::waitForDriverReady() {
+ return global_func_table_.wifi_wait_for_driver_ready();
+}
+
std::pair<wifi_error, std::string> WifiLegacyHal::getDriverVersion(
const std::string& iface_name) {
std::array<char, kMaxVersionStringLength> buffer;
@@ -715,9 +719,29 @@
wifi_iface_stat* iface_stats_ptr, int num_radios,
wifi_radio_stat* radio_stats_ptr) {
wifi_radio_stat* l_radio_stats_ptr;
+ wifi_peer_info* l_peer_info_stats_ptr;
if (iface_stats_ptr != nullptr) {
link_stats_ptr->iface = *iface_stats_ptr;
+ l_peer_info_stats_ptr = iface_stats_ptr->peer_info;
+ for (uint32_t i = 0; i < iface_stats_ptr->num_peers; i++) {
+ WifiPeerInfo peer;
+ peer.peer_info = *l_peer_info_stats_ptr;
+ if (l_peer_info_stats_ptr->num_rate > 0) {
+ /* Copy the rate stats */
+ peer.rate_stats.assign(
+ l_peer_info_stats_ptr->rate_stats,
+ l_peer_info_stats_ptr->rate_stats +
+ l_peer_info_stats_ptr->num_rate);
+ }
+ peer.peer_info.num_rate = 0;
+ link_stats_ptr->peers.push_back(peer);
+ l_peer_info_stats_ptr =
+ (wifi_peer_info*)((u8*)l_peer_info_stats_ptr +
+ sizeof(wifi_peer_info) +
+ (sizeof(wifi_rate_stat) *
+ l_peer_info_stats_ptr->num_rate));
+ }
link_stats_ptr->iface.num_peers = 0;
} else {
LOG(ERROR) << "Invalid iface stats in link layer stats";
diff --git a/wifi/1.5/default/wifi_legacy_hal.h b/wifi/1.5/default/wifi_legacy_hal.h
index 03ca841..8ebc66a 100644
--- a/wifi/1.5/default/wifi_legacy_hal.h
+++ b/wifi/1.5/default/wifi_legacy_hal.h
@@ -340,9 +340,15 @@
std::vector<wifi_channel_stat> channel_stats;
};
+struct WifiPeerInfo {
+ wifi_peer_info peer_info;
+ std::vector<wifi_rate_stat> rate_stats;
+};
+
struct LinkLayerStats {
wifi_iface_stat iface;
std::vector<LinkLayerRadioStats> radios;
+ std::vector<WifiPeerInfo> peers;
};
#pragma GCC diagnostic pop
@@ -473,6 +479,7 @@
// using a predefined timeout.
virtual wifi_error stop(std::unique_lock<std::recursive_mutex>* lock,
const std::function<void()>& on_complete_callback);
+ virtual wifi_error waitForDriverReady();
// Checks if legacy HAL has successfully started
bool isStarted();
// Wrappers for all the functions in the legacy HAL function table.
diff --git a/wifi/1.5/types.hal b/wifi/1.5/types.hal
index e1c0d32..0543004 100644
--- a/wifi/1.5/types.hal
+++ b/wifi/1.5/types.hal
@@ -26,6 +26,7 @@
import @1.3::StaLinkLayerRadioStats;
import @1.0::WifiChannelInMhz;
import @1.0::WifiChannelWidthInMhz;
+import @1.4::WifiRateInfo;
/**
* Wifi bands defined in 80211 spec.
@@ -162,6 +163,54 @@
};
/**
+ * Per rate statistics. The rate is characterized by the combination of preamble, number of spatial
+ * streams, transmission bandwidth, and modulation and coding scheme (MCS).
+ */
+struct StaRateStat{
+ /**
+ * Wifi rate information: preamble, number of spatial streams, bandwidth, MCS, etc.
+ */
+ WifiRateInfo rateInfo;
+ /**
+ * Number of successfully transmitted data packets (ACK received)
+ */
+ uint32_t txMpdu;
+ /**
+ * Number of received data packets
+ */
+ uint32_t rxMpdu;
+ /**
+ * Number of data packet losses (no ACK)
+ */
+ uint32_t mpduLost;
+ /**
+ * Number of data packet retries
+ */
+ uint32_t retries;
+};
+
+/**
+ * Per peer statistics. The types of peer include the Access Point (AP), the Tunneled Direct Link
+ * Setup (TDLS), the Group Owner (GO), the Neighbor Awareness Networking (NAN), etc.
+ */
+struct StaPeerInfo {
+ /**
+ * Station count: The total number of stations currently associated with the peer.
+ */
+ uint16_t staCount;
+ /**
+ * Channel utilization: The percentage of time (normalized to 255, i.e., x% corresponds to
+ * (int) x * 255 / 100) that the medium is sensed as busy measured by either physical or
+ * virtual carrier sense (CS) mechanism.
+ */
+ uint16_t chanUtil;
+ /**
+ * Per rate statistics
+ */
+ vec<StaRateStat> rateStats;
+};
+
+/**
* Iface statistics for the current connection.
*/
struct StaLinkLayerIfaceStats {
@@ -197,6 +246,11 @@
* WME Voice (VO) Access Category (AC) contention time statistics.
*/
StaLinkLayerIfaceContentionTimeStats wmeVoContentionTimeStats;
+
+ /**
+ * Per peer statistics.
+ */
+ vec<StaPeerInfo> peers;
};
/**
diff --git a/wifi/supplicant/1.4/Android.bp b/wifi/supplicant/1.4/Android.bp
index b486687..c988fdb 100644
--- a/wifi/supplicant/1.4/Android.bp
+++ b/wifi/supplicant/1.4/Android.bp
@@ -16,6 +16,7 @@
"types.hal",
"ISupplicant.hal",
"ISupplicantP2pIface.hal",
+ "ISupplicantP2pIfaceCallback.hal",
"ISupplicantStaIface.hal",
"ISupplicantStaNetwork.hal",
"ISupplicantStaNetworkCallback.hal",
diff --git a/wifi/supplicant/1.4/ISupplicantP2pIface.hal b/wifi/supplicant/1.4/ISupplicantP2pIface.hal
index 65c761d..28846de 100644
--- a/wifi/supplicant/1.4/ISupplicantP2pIface.hal
+++ b/wifi/supplicant/1.4/ISupplicantP2pIface.hal
@@ -17,6 +17,7 @@
package android.hardware.wifi.supplicant@1.4;
import @1.2::ISupplicantP2pIface;
+import ISupplicantP2pIfaceCallback;
/**
* Interface exposed by the supplicant for each P2P mode network
@@ -48,4 +49,36 @@
* @return enabled true if set, false otherwise.
*/
getEdmg() generates (SupplicantStatus status, bool enabled);
+
+ /**
+ * Register for callbacks from this interface.
+ *
+ * These callbacks are invoked for events that are specific to this interface.
+ * Registration of multiple callback objects is supported. These objects must
+ * be automatically deleted when the corresponding client process is dead or
+ * if this interface is removed.
+ *
+ * @param callback An instance of the |ISupplicantP2pIfaceCallback| HIDL
+ * interface object.
+ * @return status Status of the operation.
+ * Possible status codes:
+ * |SupplicantStatusCode.SUCCESS|,
+ * |SupplicantStatusCode.FAILURE_UNKNOWN|,
+ * |SupplicantStatusCode.FAILURE_IFACE_INVALID|
+ */
+ registerCallback_1_4(ISupplicantP2pIfaceCallback callback)
+ generates (SupplicantStatus status);
+
+ /**
+ * Set Wifi Display R2 device info.
+ *
+ * @param info WFD R2 device info as described in section 5.1.12 of WFD technical
+ * specification v2.1.
+ * @return status Status of the operation.
+ * Possible status codes:
+ * |SupplicantStatusCode.SUCCESS|,
+ * |SupplicantStatusCode.FAILURE_UNKNOWN|,
+ * |SupplicantStatusCode.FAILURE_IFACE_INVALID|
+ */
+ setWfdR2DeviceInfo(uint8_t[4] info) generates (SupplicantStatus status);
};
diff --git a/wifi/supplicant/1.4/ISupplicantP2pIfaceCallback.hal b/wifi/supplicant/1.4/ISupplicantP2pIfaceCallback.hal
new file mode 100644
index 0000000..a091274
--- /dev/null
+++ b/wifi/supplicant/1.4/ISupplicantP2pIfaceCallback.hal
@@ -0,0 +1,61 @@
+/*
+ * Copyright 2021 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package android.hardware.wifi.supplicant@1.4;
+
+import @1.0::ISupplicantP2pIfaceCallback;
+import @1.0::MacAddress;
+import @1.0::WpsConfigMethods;
+import @1.0::P2pGroupCapabilityMask;
+
+/**
+ * Callback Interface exposed by the supplicant service
+ * for each P2P mode interface (ISupplicantP2pIface).
+ *
+ * Clients need to host an instance of this HIDL interface object and
+ * pass a reference of the object to the supplicant via the
+ * corresponding |ISupplicantP2pIface.registerCallback_1_4| method.
+ */
+interface ISupplicantP2pIfaceCallback extends @1.0::ISupplicantP2pIfaceCallback {
+ /**
+ * Used to indicate that a P2P Wi-Fi Display R2 device has been found. Refer to
+ * Wi-Fi Display Technical Specification Version 2.0.
+ *
+ * @param srcAddress MAC address of the device found. This must either
+ * be the P2P device address for a peer which is not in a group,
+ * or the P2P interface address for a peer which is a Group Owner.
+ * @param p2pDeviceAddress P2P device address.
+ * @param primaryDeviceType Type of device. Refer to section B.1 of Wifi P2P
+ * Technical specification v1.2.
+ * @param deviceName Name of the device.
+ * @param configMethods Mask of WPS configuration methods supported by the
+ * device.
+ * @param deviceCapabilities Refer to section 4.1.4 of Wifi P2P Technical
+ * specification v1.2.
+ * @param groupCapabilities Refer to section 4.1.4 of Wifi P2P Technical
+ * specification v1.2.
+ * @param wfdDeviceInfo WFD device info as described in section 5.1.2 of WFD
+ * technical specification v1.0.0.
+ * @param wfdR2DeviceInfo WFD R2 device info as described in section 5.1.12 of WFD
+ * technical specification v2.1.
+ */
+ oneway onR2DeviceFound(
+ MacAddress srcAddress, MacAddress p2pDeviceAddress,
+ uint8_t[8] primaryDeviceType, string deviceName,
+ bitfield<WpsConfigMethods> configMethods, uint8_t deviceCapabilities,
+ bitfield<P2pGroupCapabilityMask> groupCapabilities, uint8_t[6] wfdDeviceInfo,
+ uint8_t[2] wfdR2DeviceInfo);
+};
diff --git a/wifi/supplicant/1.4/ISupplicantStaNetwork.hal b/wifi/supplicant/1.4/ISupplicantStaNetwork.hal
index 6bed5ab..4f95213 100644
--- a/wifi/supplicant/1.4/ISupplicantStaNetwork.hal
+++ b/wifi/supplicant/1.4/ISupplicantStaNetwork.hal
@@ -55,6 +55,24 @@
};
/**
+ * SAE Hash-to-Element (H2E) mode, as passed to |setSaeH2eMode|.
+ */
+ enum SaeH2eMode : uint8_t {
+ /**
+ * Hash-to-Element is disabled, only Hunting & Pecking is allowed.
+ */
+ DISABLED,
+ /**
+ * Both Hash-to-Element and Hunting & Pecking are allowed.
+ */
+ H2E_OPTIONAL,
+ /**
+ * Only Hash-to-Element is allowed.
+ */
+ H2E_MANDATORY,
+ };
+
+ /**
* Set group cipher mask for the network.
*
* @param groupCipherMask value to set.
@@ -154,22 +172,16 @@
generates (SupplicantStatus status);
/**
- * Set whether to enable SAE H2E (Hash-to-Element) only mode.
+ * Set SAE H2E (Hash-to-Element) mode.
*
- * When enabled, only SAE H2E network is allowed; othewise
- * H&P (Hunting and Pecking) and H2E are both allowed.
- * H&P only mode is not supported.
- * If this API is not called before connecting to a SAE
- * network, the behavior is undefined.
- *
- * @param enable true to set, false otherwise.
+ * @param mode SAE H2E mode to use; one of the |SaeH2eMode| values.
* @return status Status of the operation.
* Possible status codes:
* |SupplicantStatusCode.SUCCESS|,
* |SupplicantStatusCode.FAILURE_UNKNOWN|,
* |SupplicantStatusCode.FAILURE_NETWORK_INVALID|
*/
- enableSaeH2eOnlyMode(bool enable) generates (SupplicantStatus status);
+ setSaeH2eMode(SaeH2eMode mode) generates (SupplicantStatus status);
/**
* Set whether to enable SAE PK (Public Key) only mode to enable public AP validation.
diff --git a/wifi/supplicant/1.4/types.hal b/wifi/supplicant/1.4/types.hal
index c39de6e..b72eb42 100644
--- a/wifi/supplicant/1.4/types.hal
+++ b/wifi/supplicant/1.4/types.hal
@@ -107,6 +107,10 @@
* WPA3 SAE Public-Key.
*/
SAE_PK = 1 << 2,
+ /**
+ * Wi-Fi Display R2.
+ */
+ WFD_R2 = 1 << 3,
};
/**
diff --git a/wifi/supplicant/1.4/vts/functional/supplicant_p2p_iface_hidl_test.cpp b/wifi/supplicant/1.4/vts/functional/supplicant_p2p_iface_hidl_test.cpp
index 9185279..4427c390 100644
--- a/wifi/supplicant/1.4/vts/functional/supplicant_p2p_iface_hidl_test.cpp
+++ b/wifi/supplicant/1.4/vts/functional/supplicant_p2p_iface_hidl_test.cpp
@@ -28,16 +28,23 @@
#include "supplicant_hidl_test_utils_1_4.h"
using ::android::sp;
+using ::android::hardware::hidl_array;
+using ::android::hardware::hidl_string;
+using ::android::hardware::hidl_vec;
+using ::android::hardware::Return;
using ::android::hardware::Void;
using ::android::hardware::wifi::supplicant::V1_0::SupplicantStatus;
using ::android::hardware::wifi::supplicant::V1_0::SupplicantStatusCode;
using ::android::hardware::wifi::supplicant::V1_4::ISupplicantP2pIface;
+using ::android::hardware::wifi::supplicant::V1_4::ISupplicantP2pIfaceCallback;
using SupplicantStatusV1_4 =
::android::hardware::wifi::supplicant::V1_4::SupplicantStatus;
using SupplicantStatusCodeV1_4 =
::android::hardware::wifi::supplicant::V1_4::SupplicantStatusCode;
+constexpr uint8_t kTestWfdR2DeviceInfo[] = {0x01, 0x01, 0x01, 0x01};
+
class SupplicantP2pIfaceHidlTest : public SupplicantHidlTestBaseV1_4 {
public:
virtual void SetUp() override {
@@ -51,6 +58,100 @@
sp<ISupplicantP2pIface> p2p_iface_;
};
+class IfaceCallback : public ISupplicantP2pIfaceCallback {
+ Return<void> onNetworkAdded(uint32_t /* id */) override { return Void(); }
+ Return<void> onNetworkRemoved(uint32_t /* id */) override { return Void(); }
+ Return<void> onDeviceFound(
+ const hidl_array<uint8_t, 6>& /* srcAddress */,
+ const hidl_array<uint8_t, 6>& /* p2pDeviceAddress */,
+ const hidl_array<uint8_t, 8>& /* primaryDeviceType */,
+ const hidl_string& /* deviceName */, uint16_t /* configMethods */,
+ uint8_t /* deviceCapabilities */, uint32_t /* groupCapabilities */,
+ const hidl_array<uint8_t, 6>& /* wfdDeviceInfo */) override {
+ return Void();
+ }
+ Return<void> onDeviceLost(
+ const hidl_array<uint8_t, 6>& /* p2pDeviceAddress */) override {
+ return Void();
+ }
+ Return<void> onFindStopped() override { return Void(); }
+ Return<void> onGoNegotiationRequest(
+ const hidl_array<uint8_t, 6>& /* srcAddress */,
+ ISupplicantP2pIfaceCallback::WpsDevPasswordId /* passwordId */)
+ override {
+ return Void();
+ }
+ Return<void> onGoNegotiationCompleted(
+ ISupplicantP2pIfaceCallback::P2pStatusCode /* status */) override {
+ return Void();
+ }
+ Return<void> onGroupFormationSuccess() override { return Void(); }
+ Return<void> onGroupFormationFailure(
+ const hidl_string& /* failureReason */) override {
+ return Void();
+ }
+ Return<void> onGroupStarted(
+ const hidl_string& /* groupIfname */, bool /* isGo */,
+ const hidl_vec<uint8_t>& /* ssid */, uint32_t /* frequency */,
+ const hidl_array<uint8_t, 32>& /* psk */,
+ const hidl_string& /* passphrase */,
+ const hidl_array<uint8_t, 6>& /* goDeviceAddress */,
+ bool /* isPersistent */) override {
+ return Void();
+ }
+ Return<void> onGroupRemoved(const hidl_string& /* groupIfname */,
+ bool /* isGo */) override {
+ return Void();
+ }
+ Return<void> onInvitationReceived(
+ const hidl_array<uint8_t, 6>& /* srcAddress */,
+ const hidl_array<uint8_t, 6>& /* goDeviceAddress */,
+ const hidl_array<uint8_t, 6>& /* bssid */,
+ uint32_t /* persistentNetworkId */,
+ uint32_t /* operatingFrequency */) override {
+ return Void();
+ }
+ Return<void> onInvitationResult(
+ const hidl_array<uint8_t, 6>& /* bssid */,
+ ISupplicantP2pIfaceCallback::P2pStatusCode /* status */) override {
+ return Void();
+ }
+ Return<void> onProvisionDiscoveryCompleted(
+ const hidl_array<uint8_t, 6>& /* p2pDeviceAddress */,
+ bool /* isRequest */,
+ ISupplicantP2pIfaceCallback::P2pProvDiscStatusCode /* status */,
+ uint16_t /* configMethods */,
+ const hidl_string& /* generatedPin */) override {
+ return Void();
+ }
+ Return<void> onServiceDiscoveryResponse(
+ const hidl_array<uint8_t, 6>& /* srcAddress */,
+ uint16_t /* updateIndicator */,
+ const hidl_vec<uint8_t>& /* tlvs */) override {
+ return Void();
+ }
+ Return<void> onStaAuthorized(
+ const hidl_array<uint8_t, 6>& /* srcAddress */,
+ const hidl_array<uint8_t, 6>& /* p2pDeviceAddress */) override {
+ return Void();
+ }
+ Return<void> onStaDeauthorized(
+ const hidl_array<uint8_t, 6>& /* srcAddress */,
+ const hidl_array<uint8_t, 6>& /* p2pDeviceAddress */) override {
+ return Void();
+ }
+ Return<void> onR2DeviceFound(
+ const hidl_array<uint8_t, 6>& /* srcAddress */,
+ const hidl_array<uint8_t, 6>& /* p2pDeviceAddress */,
+ const hidl_array<uint8_t, 8>& /* primaryDeviceType */,
+ const hidl_string& /* deviceName */, uint16_t /* configMethods */,
+ uint8_t /* deviceCapabilities */, uint32_t /* groupCapabilities */,
+ const hidl_array<uint8_t, 6>& /* wfdDeviceInfo */,
+ const hidl_array<uint8_t, 2>& /* wfdR2DeviceInfo */) override {
+ return Void();
+ }
+};
+
/*
* SetGetEdmg
*/
@@ -71,6 +172,26 @@
});
}
+/*
+ * RegisterCallback_1_4
+ */
+TEST_P(SupplicantP2pIfaceHidlTest, RegisterCallback_1_4) {
+ p2p_iface_->registerCallback_1_4(
+ new IfaceCallback(), [](const SupplicantStatusV1_4& status) {
+ EXPECT_EQ(SupplicantStatusCodeV1_4::SUCCESS, status.code);
+ });
+}
+
+/*
+ * SetWfdR2DeviceInfo
+ */
+TEST_P(SupplicantP2pIfaceHidlTest, SetWfdR2DeviceInfo) {
+ p2p_iface_->setWfdR2DeviceInfo(
+ kTestWfdR2DeviceInfo, [](const SupplicantStatusV1_4& status) {
+ EXPECT_EQ(SupplicantStatusCodeV1_4::SUCCESS, status.code);
+ });
+}
+
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(SupplicantP2pIfaceHidlTest);
INSTANTIATE_TEST_CASE_P(
PerInstance, SupplicantP2pIfaceHidlTest,
diff --git a/wifi/supplicant/1.4/vts/functional/supplicant_sta_network_hidl_test.cpp b/wifi/supplicant/1.4/vts/functional/supplicant_sta_network_hidl_test.cpp
index 0e38c4b..e3fbaf3 100644
--- a/wifi/supplicant/1.4/vts/functional/supplicant_sta_network_hidl_test.cpp
+++ b/wifi/supplicant/1.4/vts/functional/supplicant_sta_network_hidl_test.cpp
@@ -42,6 +42,8 @@
using ::android::hardware::wifi::supplicant::V1_4::
ISupplicantStaNetworkCallback;
using ::android::hardware::wifi::V1_0::IWifi;
+using ISupplicantStaNetworkV1_4 =
+ ::android::hardware::wifi::supplicant::V1_4::ISupplicantStaNetwork;
using SupplicantStatusV1_4 =
::android::hardware::wifi::supplicant::V1_4::SupplicantStatus;
using SupplicantStatusCodeV1_4 =
@@ -110,15 +112,24 @@
}
/*
- * enable SAE H2E (Hash-to-Element) only mode
+ * SetSaeH2eMode: each SaeH2eMode value (DISABLED, H2E_MANDATORY, H2E_OPTIONAL) must succeed
*/
-TEST_P(SupplicantStaNetworkHidlTest, EnableSaeH2eOnlyMode) {
- v1_4->enableSaeH2eOnlyMode(true, [&](const SupplicantStatusV1_4& status) {
- EXPECT_EQ(SupplicantStatusCodeV1_4::SUCCESS, status.code);
- });
- v1_4->enableSaeH2eOnlyMode(false, [&](const SupplicantStatusV1_4& status) {
- EXPECT_EQ(SupplicantStatusCodeV1_4::SUCCESS, status.code);
- });
+TEST_P(SupplicantStaNetworkHidlTest, SetSaeH2eMode) {
+ v1_4->setSaeH2eMode(ISupplicantStaNetworkV1_4::SaeH2eMode::DISABLED,
+ [&](const SupplicantStatusV1_4& status) {
+ EXPECT_EQ(SupplicantStatusCodeV1_4::SUCCESS,
+ status.code);
+ });
+ v1_4->setSaeH2eMode(ISupplicantStaNetworkV1_4::SaeH2eMode::H2E_MANDATORY,
+ [&](const SupplicantStatusV1_4& status) {
+ EXPECT_EQ(SupplicantStatusCodeV1_4::SUCCESS,
+ status.code);
+ });
+ v1_4->setSaeH2eMode(ISupplicantStaNetworkV1_4::SaeH2eMode::H2E_OPTIONAL,
+ [&](const SupplicantStatusV1_4& status) {
+ EXPECT_EQ(SupplicantStatusCodeV1_4::SUCCESS,
+ status.code);
+ });
}
/*