Fix wait/notify logic in GNSS VTS 2.0 test cases

Fixes: 131869042
Test: atest VtsHalGnssV2_0TargetTest
Change-Id: I6fe5713c0f1d329f2738a2d4ba4a7d5aa58efec9
diff --git a/gnss/2.0/vts/functional/gnss_hal_test.cpp b/gnss/2.0/vts/functional/gnss_hal_test.cpp
index da6092b..febd0f1 100644
--- a/gnss/2.0/vts/functional/gnss_hal_test.cpp
+++ b/gnss/2.0/vts/functional/gnss_hal_test.cpp
@@ -23,38 +23,34 @@
 using ::android::hardware::gnss::common::Utils;
 
 // Implementations for the main test class for GNSS HAL
-GnssHalTest::GnssHalTest()
-    : info_called_count_(0),
-      capabilities_called_count_(0),
-      measurement_corrections_capabilities_called_count_(0),
-      location_called_count_(0),
-      name_called_count_(0),
-      notify_count_(0) {}
+GnssHalTest::GnssHalTest() {}
 
 void GnssHalTest::SetUp() {
     gnss_hal_ = ::testing::VtsHalHidlTargetTestBase::getService<IGnss>(
         GnssHidlEnvironment::Instance()->getServiceName<IGnss>());
-    list_vec_gnss_sv_info_.clear();
     ASSERT_NE(gnss_hal_, nullptr);
 
     SetUpGnssCallback();
 }
 
 void GnssHalTest::TearDown() {
-    // Reset counters
-    info_called_count_ = 0;
-    capabilities_called_count_ = 0;
-    measurement_corrections_capabilities_called_count_ = 0;
-    location_called_count_ = 0;
-    name_called_count_ = 0;
-    measurement_called_count_ = 0;
-
     if (gnss_hal_ != nullptr) {
         gnss_hal_->cleanup();
     }
-    if (notify_count_ > 0) {
-        ALOGW("%d unprocessed callbacks discarded", notify_count_);
+
+    int unprocessedEventsCount = measurement_cbq_.size() + location_cbq_.size();
+    if (unprocessedEventsCount > 0) {
+        ALOGW("%d unprocessed callbacks discarded", unprocessedEventsCount);
     }
+
+    // Reset all callback event queues.
+    info_cbq_.reset();
+    name_cbq_.reset();
+    top_hal_capabilities_cbq_.reset();
+    measurement_corrections_capabilities_cbq_.reset();
+    measurement_cbq_.reset();
+    location_cbq_.reset();
+    sv_info_cbq_.reset();
 }
 
 void GnssHalTest::SetUpGnssCallback() {
@@ -72,13 +68,13 @@
     /*
      * All capabilities, name and systemInfo callbacks should trigger
      */
-    EXPECT_EQ(std::cv_status::no_timeout, wait(TIMEOUT_SEC));
-    EXPECT_EQ(std::cv_status::no_timeout, wait(TIMEOUT_SEC));
-    EXPECT_EQ(std::cv_status::no_timeout, wait(TIMEOUT_SEC));
+    EXPECT_TRUE(top_hal_capabilities_cbq_.retrieve(last_capabilities_, TIMEOUT_SEC));
+    EXPECT_TRUE(info_cbq_.retrieve(last_info_, TIMEOUT_SEC));
+    EXPECT_TRUE(name_cbq_.retrieve(last_name_, TIMEOUT_SEC));
 
-    EXPECT_EQ(capabilities_called_count_, 1);
-    EXPECT_EQ(info_called_count_, 1);
-    EXPECT_EQ(name_called_count_, 1);
+    EXPECT_EQ(top_hal_capabilities_cbq_.calledCount(), 1);
+    EXPECT_EQ(info_cbq_.calledCount(), 1);
+    EXPECT_EQ(name_cbq_.calledCount(), 1);
 }
 
 void GnssHalTest::StopAndClearLocations() {
@@ -92,9 +88,8 @@
      * the last reply for final startup messages to arrive (esp. system
      * info.)
      */
-    while (wait(TIMEOUT_SEC) == std::cv_status::no_timeout) {
-    }
-    location_called_count_ = 0;
+    location_cbq_.waitUntilEmpty(TIMEOUT_SEC);
+    location_cbq_.reset();
 }
 
 void GnssHalTest::SetPositionMode(const int min_interval_msec, const bool low_power_mode) {
@@ -121,10 +116,11 @@
      */
     const int kFirstGnssLocationTimeoutSeconds = 75;
 
-    wait(kFirstGnssLocationTimeoutSeconds);
-    EXPECT_EQ(location_called_count_, 1);
+    EXPECT_TRUE(location_cbq_.retrieve(last_location_, kFirstGnssLocationTimeoutSeconds));
+    int locationCalledCount = location_cbq_.calledCount();
+    EXPECT_EQ(locationCalledCount, 1);
 
-    if (location_called_count_ > 0) {
+    if (locationCalledCount > 0) {
         // don't require speed on first fix
         CheckLocation(last_location_, false);
         return true;
@@ -133,7 +129,7 @@
 }
 
 void GnssHalTest::CheckLocation(const GnssLocation_2_0& location, bool check_speed) {
-    const bool check_more_accuracies = (info_called_count_ > 0 && last_info_.yearOfHw >= 2017);
+    const bool check_more_accuracies = (info_cbq_.calledCount() > 0 && last_info_.yearOfHw >= 2017);
 
     Utils::checkLocation(location.v1_0, check_speed, check_more_accuracies);
 }
@@ -148,77 +144,39 @@
     EXPECT_TRUE(StartAndCheckFirstLocation());
 
     for (int i = 1; i < count; i++) {
-        EXPECT_EQ(std::cv_status::no_timeout, wait(kLocationTimeoutSubsequentSec));
-        EXPECT_EQ(location_called_count_, i + 1);
+        EXPECT_TRUE(location_cbq_.retrieve(last_location_, kLocationTimeoutSubsequentSec));
+        int locationCalledCount = location_cbq_.calledCount();
+        EXPECT_EQ(locationCalledCount, i + 1);
         // Don't cause confusion by checking details if no location yet
-        if (location_called_count_ > 0) {
+        if (locationCalledCount > 0) {
             // Should be more than 1 location by now, but if not, still don't check first fix speed
-            CheckLocation(last_location_, location_called_count_ > 1);
+            CheckLocation(last_location_, locationCalledCount > 1);
         }
     }
 }
 
-void GnssHalTest::notify() {
-    {
-        std::unique_lock<std::mutex> lock(mtx_);
-        notify_count_++;
-    }
-    cv_.notify_one();
-}
-
-std::cv_status GnssHalTest::wait(int timeout_seconds) {
-    std::unique_lock<std::mutex> lock(mtx_);
-
-    auto status = std::cv_status::no_timeout;
-    while (notify_count_ == 0) {
-        status = cv_.wait_for(lock, std::chrono::seconds(timeout_seconds));
-        if (status == std::cv_status::timeout) return status;
-    }
-    notify_count_--;
-    return status;
-}
-
-std::cv_status GnssHalTest::waitForMeasurementCorrectionsCapabilities(int timeout_seconds) {
-    std::unique_lock<std::mutex> lock(mtx_);
-    auto status = std::cv_status::no_timeout;
-    while (measurement_corrections_capabilities_called_count_ == 0) {
-        status = cv_.wait_for(lock, std::chrono::seconds(timeout_seconds));
-        if (status == std::cv_status::timeout) return status;
-    }
-    notify_count_--;
-    return status;
-}
-
 Return<void> GnssHalTest::GnssCallback::gnssSetSystemInfoCb(
         const IGnssCallback_1_0::GnssSystemInfo& info) {
     ALOGI("Info received, year %d", info.yearOfHw);
-    parent_.info_called_count_++;
-    parent_.last_info_ = info;
-    parent_.notify();
+    parent_.info_cbq_.store(info);
     return Void();
 }
 
 Return<void> GnssHalTest::GnssCallback::gnssSetCapabilitesCb(uint32_t capabilities) {
     ALOGI("Capabilities received %d", capabilities);
-    parent_.capabilities_called_count_++;
-    parent_.last_capabilities_ = capabilities;
-    parent_.notify();
+    parent_.top_hal_capabilities_cbq_.store(capabilities);
     return Void();
 }
 
 Return<void> GnssHalTest::GnssCallback::gnssSetCapabilitiesCb_2_0(uint32_t capabilities) {
     ALOGI("Capabilities (v2.0) received %d", capabilities);
-    parent_.capabilities_called_count_++;
-    parent_.last_capabilities_ = capabilities;
-    parent_.notify();
+    parent_.top_hal_capabilities_cbq_.store(capabilities);
     return Void();
 }
 
 Return<void> GnssHalTest::GnssCallback::gnssNameCb(const android::hardware::hidl_string& name) {
     ALOGI("Name received: %s", name.c_str());
-    parent_.name_called_count_++;
-    parent_.last_name_ = name;
-    parent_.notify();
+    parent_.name_cbq_.store(name);
     return Void();
 }
 
@@ -235,40 +193,32 @@
 }
 
 Return<void> GnssHalTest::GnssCallback::gnssLocationCbImpl(const GnssLocation_2_0& location) {
-    parent_.location_called_count_++;
-    parent_.last_location_ = location;
-    parent_.notify();
+    parent_.location_cbq_.store(location);
     return Void();
 }
 
 Return<void> GnssHalTest::GnssCallback::gnssSvStatusCb(const IGnssCallback_1_0::GnssSvStatus&) {
     ALOGI("gnssSvStatusCb");
-
     return Void();
 }
 
 Return<void> GnssHalTest::GnssMeasurementCallback::gnssMeasurementCb_2_0(
     const IGnssMeasurementCallback_2_0::GnssData& data) {
     ALOGD("GnssMeasurement received. Size = %d", (int)data.measurements.size());
-    parent_.measurement_called_count_++;
-    parent_.last_measurement_ = data;
-    parent_.notify();
+    parent_.measurement_cbq_.store(data);
     return Void();
 }
 
 Return<void> GnssHalTest::GnssMeasurementCorrectionsCallback::setCapabilitiesCb(
         uint32_t capabilities) {
     ALOGI("GnssMeasurementCorrectionsCallback capabilities received %d", capabilities);
-    parent_.measurement_corrections_capabilities_called_count_++;
-    parent_.last_measurement_corrections_capabilities_ = capabilities;
-    parent_.notify();
+    parent_.measurement_corrections_capabilities_cbq_.store(capabilities);
     return Void();
 }
 
 Return<void> GnssHalTest::GnssCallback::gnssSvStatusCb_2_0(
         const hidl_vec<IGnssCallback_2_0::GnssSvInfo>& svInfoList) {
     ALOGI("gnssSvStatusCb_2_0. Size = %d", (int)svInfoList.size());
-    parent_.list_vec_gnss_sv_info_.emplace_back(svInfoList);
-    parent_.notify();
+    parent_.sv_info_cbq_.store(svInfoList);
     return Void();
 }
diff --git a/gnss/2.0/vts/functional/gnss_hal_test.h b/gnss/2.0/vts/functional/gnss_hal_test.h
index 737815f..8e440ff 100644
--- a/gnss/2.0/vts/functional/gnss_hal_test.h
+++ b/gnss/2.0/vts/functional/gnss_hal_test.h
@@ -22,7 +22,7 @@
 #include <VtsHalHidlTargetTestEnvBase.h>
 
 #include <condition_variable>
-#include <list>
+#include <deque>
 #include <mutex>
 
 using android::hardware::hidl_vec;
@@ -125,7 +125,7 @@
 
     /* Callback class for GnssMeasurement. */
     class GnssMeasurementCallback : public IGnssMeasurementCallback_2_0 {
-       public:
+      public:
         GnssHalTest& parent_;
         GnssMeasurementCallback(GnssHalTest& parent) : parent_(parent){};
         virtual ~GnssMeasurementCallback() = default;
@@ -155,6 +155,77 @@
         Return<void> setCapabilitiesCb(uint32_t capabilities) override;
     };
 
+    /* Producer/consumer queue for storing/retrieving callback events from GNSS HAL */
+    template <class T>
+    class CallbackQueue {
+      public:
+        CallbackQueue() : called_count_(0){};
+
+        /* Adds callback event to the end of the queue. */
+        void store(const T& event) {
+            std::unique_lock<std::recursive_mutex> lock(mtx_);
+            events_.push_back(event);
+            ++called_count_;
+            lock.unlock();
+            cv_.notify_all();
+        }
+
+        /*
+         * Removes the callack event at the front of the queue, stores it in event parameter
+         * and returns true. If the timeout occurs waiting for callback event, returns false.
+         */
+        bool retrieve(T& event, int timeout_seconds) {
+            std::unique_lock<std::recursive_mutex> lock(mtx_);
+            cv_.wait_for(lock, std::chrono::seconds(timeout_seconds),
+                         [&] { return !events_.empty(); });
+            if (events_.empty()) {
+                return false;
+            }
+            event = events_.front();
+            events_.pop_front();
+            return true;
+        }
+
+        /* Returns the number of events pending to be retrieved from the callback event queue. */
+        int size() const {
+            std::unique_lock<std::recursive_mutex> lock(mtx_);
+            return events_.size();
+        }
+
+        /* Returns the number of callback events received since last reset(). */
+        int calledCount() const {
+            std::unique_lock<std::recursive_mutex> lock(mtx_);
+            return called_count_;
+        }
+
+        /* Clears the callback event queue and resets the calledCount() to 0. */
+        void reset() {
+            std::unique_lock<std::recursive_mutex> lock(mtx_);
+            events_.clear();
+            called_count_ = 0;
+        }
+
+        /*
+         * Blocks the calling thread until the callback event queue becomes empty or timeout
+         * occurs. Returns false on timeout.
+         */
+        bool waitUntilEmpty(int timeout_seconds) {
+            std::unique_lock<std::recursive_mutex> lock(mtx_);
+            cv_.wait_for(lock, std::chrono::seconds(timeout_seconds),
+                         [&] { return events_.empty(); });
+            return !events_.empty();
+        }
+
+      private:
+        CallbackQueue(const CallbackQueue&) = delete;
+        CallbackQueue& operator=(const CallbackQueue&) = delete;
+
+        mutable std::recursive_mutex mtx_;
+        std::condition_variable_any cv_;
+        std::deque<T> events_;
+        int called_count_;
+    };
+
     /*
      * SetUpGnssCallback:
      *   Set GnssCallback and verify the result.
@@ -205,30 +276,19 @@
     sp<IGnss> gnss_hal_;         // GNSS HAL to call into
     sp<IGnssCallback_2_0> gnss_cb_;  // Primary callback interface
 
-    // TODO: make these variables thread-safe.
-    /* Count of calls to set the following items, and the latest item (used by
-     * test.)
-     */
-    int info_called_count_;
-    int capabilities_called_count_;
-    int measurement_corrections_capabilities_called_count_;
-    int location_called_count_;
-    int measurement_called_count_;
-    int name_called_count_;
-
     IGnssCallback_1_0::GnssSystemInfo last_info_;
     uint32_t last_capabilities_;
     uint32_t last_measurement_corrections_capabilities_;
     GnssLocation_2_0 last_location_;
-    IGnssMeasurementCallback_2_0::GnssData last_measurement_;
     android::hardware::hidl_string last_name_;
 
-    list<hidl_vec<IGnssCallback_2_0::GnssSvInfo>> list_vec_gnss_sv_info_;
-
-  private:
-    std::mutex mtx_;
-    std::condition_variable cv_;
-    int notify_count_;
+    CallbackQueue<IGnssCallback_1_0::GnssSystemInfo> info_cbq_;
+    CallbackQueue<android::hardware::hidl_string> name_cbq_;
+    CallbackQueue<uint32_t> top_hal_capabilities_cbq_;
+    CallbackQueue<uint32_t> measurement_corrections_capabilities_cbq_;
+    CallbackQueue<IGnssMeasurementCallback_2_0::GnssData> measurement_cbq_;
+    CallbackQueue<GnssLocation_2_0> location_cbq_;
+    CallbackQueue<hidl_vec<IGnssCallback_2_0::GnssSvInfo>> sv_info_cbq_;
 };
 
 #endif  // GNSS_HAL_TEST_H_
diff --git a/gnss/2.0/vts/functional/gnss_hal_test_cases.cpp b/gnss/2.0/vts/functional/gnss_hal_test_cases.cpp
index 7c253b0..009f43d 100644
--- a/gnss/2.0/vts/functional/gnss_hal_test_cases.cpp
+++ b/gnss/2.0/vts/functional/gnss_hal_test_cases.cpp
@@ -199,10 +199,11 @@
     ASSERT_TRUE(result.isOk());
     EXPECT_EQ(result, IGnssMeasurement_1_0::GnssMeasurementStatus::SUCCESS);
 
-    wait(kFirstGnssMeasurementTimeoutSeconds);
-    EXPECT_EQ(measurement_called_count_, 1);
-    ASSERT_TRUE(last_measurement_.measurements.size() > 0);
-    for (auto measurement : last_measurement_.measurements) {
+    IGnssMeasurementCallback_2_0::GnssData lastMeasurement;
+    ASSERT_TRUE(measurement_cbq_.retrieve(lastMeasurement, kFirstGnssMeasurementTimeoutSeconds));
+    EXPECT_EQ(measurement_cbq_.calledCount(), 1);
+    ASSERT_TRUE(lastMeasurement.measurements.size() > 0);
+    for (auto measurement : lastMeasurement.measurements) {
         // Verify CodeType is valid.
         ASSERT_NE(measurement.codeType, "");
 
@@ -305,8 +306,10 @@
     iMeasurementCorrections->setCallback(iMeasurementCorrectionsCallback);
 
     const int kMeasurementCorrectionsCapabilitiesTimeoutSeconds = 5;
-    waitForMeasurementCorrectionsCapabilities(kMeasurementCorrectionsCapabilitiesTimeoutSeconds);
-    ASSERT_TRUE(measurement_corrections_capabilities_called_count_ > 0);
+    measurement_corrections_capabilities_cbq_.retrieve(
+            last_measurement_corrections_capabilities_,
+            kMeasurementCorrectionsCapabilitiesTimeoutSeconds);
+    ASSERT_TRUE(measurement_corrections_capabilities_cbq_.calledCount() > 0);
     using Capabilities = IMeasurementCorrectionsCallback::Capabilities;
     ASSERT_TRUE((last_measurement_corrections_capabilities_ &
                  (Capabilities::LOS_SATS | Capabilities::EXCESS_PATH_LENGTH)) != 0);
@@ -333,8 +336,11 @@
     iMeasurementCorrections->setCallback(iMeasurementCorrectionsCallback);
 
     const int kMeasurementCorrectionsCapabilitiesTimeoutSeconds = 5;
-    waitForMeasurementCorrectionsCapabilities(kMeasurementCorrectionsCapabilitiesTimeoutSeconds);
-    ASSERT_TRUE(measurement_corrections_capabilities_called_count_ > 0);
+    measurement_corrections_capabilities_cbq_.retrieve(
+            last_measurement_corrections_capabilities_,
+            kMeasurementCorrectionsCapabilitiesTimeoutSeconds);
+    ASSERT_TRUE(measurement_corrections_capabilities_cbq_.calledCount() > 0);
+
     // Set a mock MeasurementCorrections.
     auto result = iMeasurementCorrections->setCorrections(Utils::getMockMeasurementCorrections());
     ASSERT_TRUE(result.isOk());
@@ -365,16 +371,17 @@
     ASSERT_TRUE(result.isOk());
     EXPECT_EQ(result, IGnssMeasurement_1_0::GnssMeasurementStatus::SUCCESS);
 
-    wait(kFirstGnssMeasurementTimeoutSeconds);
-    EXPECT_EQ(measurement_called_count_, 1);
+    IGnssMeasurementCallback_2_0::GnssData lastMeasurement;
+    ASSERT_TRUE(measurement_cbq_.retrieve(lastMeasurement, kFirstGnssMeasurementTimeoutSeconds));
+    EXPECT_EQ(measurement_cbq_.calledCount(), 1);
 
-    ASSERT_TRUE((int)last_measurement_.elapsedRealtime.flags <=
+    ASSERT_TRUE((int)lastMeasurement.elapsedRealtime.flags <=
                 (int)(ElapsedRealtimeFlags::HAS_TIMESTAMP_NS |
                       ElapsedRealtimeFlags::HAS_TIME_UNCERTAINTY_NS));
 
     // We expect a non-zero timestamp when set.
-    if (last_measurement_.elapsedRealtime.flags & ElapsedRealtimeFlags::HAS_TIMESTAMP_NS) {
-        ASSERT_TRUE(last_measurement_.elapsedRealtime.timestampNs != 0);
+    if (lastMeasurement.elapsedRealtime.flags & ElapsedRealtimeFlags::HAS_TIMESTAMP_NS) {
+        ASSERT_TRUE(lastMeasurement.elapsedRealtime.timestampNs != 0);
     }
 
     iGnssMeasurement->close();