Simplify/clarify watchdog code

- Use Debug trait object rather than arbitrary closure
- Combine `has_overdue: bool` + `last_report: Instant` into a single
  `last_report: Option<Instant>`
- Rename `WatchdogState::timeout` to `idle_timeout` to disambiguate it
  from per-watchpoint timeouts
- Add comments
- Simplify multi-argument matches

Test: keystore2_test
Test: libwatchdog_rs.test
Flag: None, refactor
Change-Id: Ieb16257c763fc2e04d592d97f341fea27aad726f
diff --git a/keystore2/src/maintenance.rs b/keystore2/src/maintenance.rs
index ba92399..6c07f0c 100644
--- a/keystore2/src/maintenance.rs
+++ b/keystore2/src/maintenance.rs
@@ -177,9 +177,7 @@
         let (km_dev, _, _) =
             get_keymint_device(&sec_level).context(ks_err!("getting keymint device"))?;
 
-        let _wp = wd::watch_millis_with("In call_with_watchdog", 500, move || {
-            format!("Seclevel: {:?} Op: {}", sec_level, name)
-        });
+        let _wp = wd::watch_millis_with("In call_with_watchdog", 500, (sec_level, name));
         map_km_error(op(km_dev)).with_context(|| ks_err!("calling {}", name))?;
         Ok(())
     }
diff --git a/keystore2/src/security_level.rs b/keystore2/src/security_level.rs
index 00e0480..8ce802e 100644
--- a/keystore2/src/security_level.rs
+++ b/keystore2/src/security_level.rs
@@ -109,14 +109,12 @@
 
     fn watch_millis(&self, id: &'static str, millis: u64) -> Option<wd::WatchPoint> {
         let sec_level = self.security_level;
-        wd::watch_millis_with(id, millis, move || format!("SecurityLevel {:?}", sec_level))
+        wd::watch_millis_with(id, millis, sec_level)
     }
 
     fn watch(&self, id: &'static str) -> Option<wd::WatchPoint> {
         let sec_level = self.security_level;
-        wd::watch_millis_with(id, wd::DEFAULT_TIMEOUT_MS, move || {
-            format!("SecurityLevel {:?}", sec_level)
-        })
+        wd::watch_millis_with(id, wd::DEFAULT_TIMEOUT_MS, sec_level)
     }
 
     fn store_new_key(
diff --git a/keystore2/src/service.rs b/keystore2/src/service.rs
index 3726358..b760a56 100644
--- a/keystore2/src/service.rs
+++ b/keystore2/src/service.rs
@@ -381,9 +381,7 @@
         &self,
         security_level: SecurityLevel,
     ) -> binder::Result<Strong<dyn IKeystoreSecurityLevel>> {
-        let _wp = wd::watch_millis_with("IKeystoreService::getSecurityLevel", 500, move || {
-            format!("security_level: {}", security_level.0)
-        });
+        let _wp = wd::watch_millis_with("IKeystoreService::getSecurityLevel", 500, security_level);
         self.get_security_level(security_level).map_err(into_logged_binder)
     }
     fn getKeyEntry(&self, key: &KeyDescriptor) -> binder::Result<KeyEntryResponse> {
diff --git a/keystore2/src/watchdog_helper.rs b/keystore2/src/watchdog_helper.rs
index 03c7740..1072ac0 100644
--- a/keystore2/src/watchdog_helper.rs
+++ b/keystore2/src/watchdog_helper.rs
@@ -43,14 +43,14 @@
         Watchdog::watch(&WD, id, DEFAULT_TIMEOUT)
     }
 
-    /// Like `watch_millis` but with a callback that is called every time a report
-    /// is printed about this watch point.
+    /// Like `watch_millis` but with context that is included every time a report is printed about
+    /// this watch point.
     pub fn watch_millis_with(
         id: &'static str,
         millis: u64,
-        callback: impl Fn() -> String + Send + 'static,
+        context: impl std::fmt::Debug + Send + 'static,
     ) -> Option<WatchPoint> {
-        Watchdog::watch_with(&WD, id, Duration::from_millis(millis), callback)
+        Watchdog::watch_with(&WD, id, Duration::from_millis(millis), context)
     }
 }
 
@@ -71,7 +71,7 @@
     pub fn watch_millis_with(
         _: &'static str,
         _: u64,
-        _: impl Fn() -> String + Send + 'static,
+        _: impl std::fmt::Debug + Send + 'static,
     ) -> Option<WatchPoint> {
         None
     }
diff --git a/keystore2/watchdog/src/lib.rs b/keystore2/watchdog/src/lib.rs
index fa4620a..1ac3ef3 100644
--- a/keystore2/watchdog/src/lib.rs
+++ b/keystore2/watchdog/src/lib.rs
@@ -58,59 +58,55 @@
 struct Record {
     started: Instant,
     deadline: Instant,
-    callback: Option<Box<dyn Fn() -> String + Send + 'static>>,
+    context: Option<Box<dyn std::fmt::Debug + Send + 'static>>,
 }
 
 struct WatchdogState {
     state: State,
     thread: Option<thread::JoinHandle<()>>,
-    timeout: Duration,
+    /// How long to wait before dropping the watchdog thread when idle.
+    idle_timeout: Duration,
     records: HashMap<Index, Record>,
-    last_report: Instant,
-    has_overdue: bool,
+    last_report: Option<Instant>,
 }
 
 impl WatchdogState {
-    fn update_overdue_and_find_next_timeout(&mut self) -> (bool, Option<Duration>) {
+    fn overdue_and_next_timeout(&self) -> (bool, Option<Duration>) {
         let now = Instant::now();
         let mut next_timeout: Option<Duration> = None;
         let mut has_overdue = false;
         for (_, r) in self.records.iter() {
             let timeout = r.deadline.saturating_duration_since(now);
             if timeout == Duration::new(0, 0) {
+                // This timeout has passed.
                 has_overdue = true;
-                continue;
+            } else {
+                // This timeout is still to come; see if it's the closest one to now.
+                next_timeout = match next_timeout {
+                    Some(nt) if timeout < nt => Some(timeout),
+                    Some(nt) => Some(nt),
+                    None => Some(timeout),
+                };
             }
-            next_timeout = match next_timeout {
-                Some(nt) => {
-                    if timeout < nt {
-                        Some(timeout)
-                    } else {
-                        Some(nt)
-                    }
-                }
-                None => Some(timeout),
-            };
         }
         (has_overdue, next_timeout)
     }
 
-    fn log_report(&mut self, has_overdue: bool) -> bool {
-        match (self.has_overdue, has_overdue) {
-            (true, true) => {
-                if self.last_report.elapsed() < Watchdog::NOISY_REPORT_TIMEOUT {
-                    self.has_overdue = false;
-                    return false;
-                }
-            }
-            (_, false) => {
-                self.has_overdue = false;
-                return false;
-            }
-            (false, true) => {}
+    fn log_report(&mut self, has_overdue: bool) {
+        if !has_overdue {
+            // Nothing to report.
+            self.last_report = None;
+            return;
         }
-        self.last_report = Instant::now();
-        self.has_overdue = has_overdue;
+        // Something to report...
+        if let Some(reported_at) = self.last_report {
+            if reported_at.elapsed() < Watchdog::NOISY_REPORT_TIMEOUT {
+                // .. but it's too soon since the last report.
+                self.last_report = None;
+                return;
+            }
+        }
+        self.last_report = Some(Instant::now());
         log::warn!("### Keystore Watchdog report - BEGIN ###");
 
         let now = Instant::now();
@@ -149,15 +145,15 @@
 
         for g in groups.iter() {
             for (i, r) in g.iter() {
-                match &r.callback {
-                    Some(cb) => {
+                match &r.context {
+                    Some(ctx) => {
                         log::warn!(
-                            "{:?} {} Pending: {:?} Overdue {:?}: {}",
+                            "{:?} {} Pending: {:?} Overdue {:?} for {:?}",
                             i.tid,
                             i.id,
                             r.started.elapsed(),
                             r.deadline.elapsed(),
-                            (cb)()
+                            ctx
                         );
                     }
                     None => {
@@ -173,7 +169,6 @@
             }
         }
         log::warn!("### Keystore Watchdog report - END ###");
-        true
     }
 
     fn disarm(&mut self, index: Index) {
@@ -199,67 +194,65 @@
     /// at least every `NOISY_REPORT_TIMEOUT` interval.
     const NOISY_REPORT_TIMEOUT: Duration = Duration::from_secs(1);
 
-    /// Construct a [`Watchdog`]. When `timeout` has elapsed since the watchdog thread became
+    /// Construct a [`Watchdog`]. When `idle_timeout` has elapsed since the watchdog thread became
     /// idle, i.e., there are no more active or overdue watch points, the watchdog thread
     /// terminates.
-    pub fn new(timeout: Duration) -> Arc<Self> {
+    pub fn new(idle_timeout: Duration) -> Arc<Self> {
         Arc::new(Self {
             state: Arc::new((
                 Condvar::new(),
                 Mutex::new(WatchdogState {
                     state: State::NotRunning,
                     thread: None,
-                    timeout,
+                    idle_timeout,
                     records: HashMap::new(),
-                    last_report: Instant::now(),
-                    has_overdue: false,
+                    last_report: None,
                 }),
             )),
         })
     }
 
     fn watch_with_optional(
-        wd: &Arc<Self>,
-        callback: Option<Box<dyn Fn() -> String + Send + 'static>>,
+        wd: Arc<Self>,
+        context: Option<Box<dyn std::fmt::Debug + Send + 'static>>,
         id: &'static str,
         timeout: Duration,
     ) -> Option<WatchPoint> {
-        let deadline = Instant::now().checked_add(timeout);
-        if deadline.is_none() {
+        let Some(deadline) = Instant::now().checked_add(timeout) else {
             log::warn!("Deadline computation failed for WatchPoint \"{}\"", id);
             log::warn!("WatchPoint not armed.");
             return None;
-        }
-        wd.arm(callback, id, deadline.unwrap());
-        Some(WatchPoint { id, wd: wd.clone(), not_send: Default::default() })
+        };
+        wd.arm(context, id, deadline);
+        Some(WatchPoint { id, wd, not_send: Default::default() })
     }
 
     /// Create a new watch point. If the WatchPoint is not dropped before the timeout
     /// expires, a report is logged at least every second, which includes the id string
-    /// and whatever string the callback returns.
+    /// and any provided context.
     pub fn watch_with(
         wd: &Arc<Self>,
         id: &'static str,
         timeout: Duration,
-        callback: impl Fn() -> String + Send + 'static,
+        context: impl std::fmt::Debug + Send + 'static,
     ) -> Option<WatchPoint> {
-        Self::watch_with_optional(wd, Some(Box::new(callback)), id, timeout)
+        Self::watch_with_optional(wd.clone(), Some(Box::new(context)), id, timeout)
     }
 
-    /// Like `watch_with`, but without a callback.
+    /// Like `watch_with`, but without context.
     pub fn watch(wd: &Arc<Self>, id: &'static str, timeout: Duration) -> Option<WatchPoint> {
-        Self::watch_with_optional(wd, None, id, timeout)
+        Self::watch_with_optional(wd.clone(), None, id, timeout)
     }
 
     fn arm(
         &self,
-        callback: Option<Box<dyn Fn() -> String + Send + 'static>>,
+        context: Option<Box<dyn std::fmt::Debug + Send + 'static>>,
         id: &'static str,
         deadline: Instant,
     ) {
         let tid = thread::current().id();
         let index = Index { tid, id };
-        let record = Record { started: Instant::now(), deadline, callback };
+        let record = Record { started: Instant::now(), deadline, context };
 
         let (ref condvar, ref state) = *self.state;
 
@@ -297,17 +290,21 @@
             let mut state = state.lock().unwrap();
 
             loop {
-                let (has_overdue, next_timeout) = state.update_overdue_and_find_next_timeout();
+                let (has_overdue, next_timeout) = state.overdue_and_next_timeout();
                 state.log_report(has_overdue);
+
                 let (next_timeout, idle) = match (has_overdue, next_timeout) {
                     (true, Some(next_timeout)) => {
                         (min(next_timeout, Self::NOISY_REPORT_TIMEOUT), false)
                     }
-                    (false, Some(next_timeout)) => (next_timeout, false),
                     (true, None) => (Self::NOISY_REPORT_TIMEOUT, false),
-                    (false, None) => (state.timeout, true),
+                    (false, Some(next_timeout)) => (next_timeout, false),
+                    (false, None) => (state.idle_timeout, true),
                 };
 
+                // Wait until the closest timeout pops, but use a condition variable so that if a
+                // new watchpoint is started in the meanwhile it will interrupt the wait so we can
+                // recalculate.
                 let (s, timeout) = condvar.wait_timeout(state, next_timeout).unwrap();
                 state = s;
 
@@ -338,21 +335,37 @@
                 .with_max_level(log::LevelFilter::Debug),
         );
 
+        /// Count the number of times `Debug::fmt` is invoked.
+        #[derive(Default, Clone)]
+        struct DebugCounter(Arc<atomic::AtomicU8>);
+        impl DebugCounter {
+            fn value(&self) -> u8 {
+                self.0.load(atomic::Ordering::Relaxed)
+            }
+        }
+        impl std::fmt::Debug for DebugCounter {
+            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+                let count = self.0.fetch_add(1, atomic::Ordering::Relaxed);
+                write!(f, "hit_count: {count}")
+            }
+        }
+
         let wd = Watchdog::new(Watchdog::NOISY_REPORT_TIMEOUT.checked_mul(3).unwrap());
-        let hit_count = Arc::new(atomic::AtomicU8::new(0));
-        let hit_count_clone = hit_count.clone();
-        let wp =
-            Watchdog::watch_with(&wd, "test_watchdog", Duration::from_millis(100), move || {
-                format!("hit_count: {}", hit_count_clone.fetch_add(1, atomic::Ordering::Relaxed))
-            });
-        assert_eq!(0, hit_count.load(atomic::Ordering::Relaxed));
+        let hit_counter = DebugCounter::default();
+        let wp = Watchdog::watch_with(
+            &wd,
+            "test_watchdog",
+            Duration::from_millis(100),
+            hit_counter.clone(),
+        );
+        assert_eq!(0, hit_counter.value());
         thread::sleep(Duration::from_millis(500));
-        assert_eq!(1, hit_count.load(atomic::Ordering::Relaxed));
+        assert_eq!(1, hit_counter.value());
         thread::sleep(Watchdog::NOISY_REPORT_TIMEOUT);
-        assert_eq!(2, hit_count.load(atomic::Ordering::Relaxed));
+        assert_eq!(2, hit_counter.value());
         drop(wp);
         thread::sleep(Watchdog::NOISY_REPORT_TIMEOUT.checked_mul(4).unwrap());
-        assert_eq!(2, hit_count.load(atomic::Ordering::Relaxed));
+        assert_eq!(2, hit_counter.value());
         let (_, ref state) = *wd.state;
         let state = state.lock().unwrap();
         assert_eq!(state.state, State::NotRunning);