/*
 * Copyright 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

//! Slow keys input filter implementation.
//! Slow keys is an accessibility feature to aid users who have physical disabilities, that allows
//! the user to specify the duration for which one must press-and-hold a key before the system
//! accepts the keypress.
use crate::input_filter::{Filter, VIRTUAL_KEYBOARD_DEVICE_ID};
use crate::input_filter_thread::{InputFilterThread, ThreadCallback};
use android_hardware_input_common::aidl::android::hardware::input::common::Source::Source;
use com_android_server_inputflinger::aidl::com::android::server::inputflinger::{
    DeviceInfo::DeviceInfo, KeyEvent::KeyEvent, KeyEventAction::KeyEventAction,
};
use input::KeyboardType;
use log::debug;
use std::any::Any;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};

/// Policy flag (mirrors the value in Input.h) telling downstream consumers not to generate key
/// repeats for the event it is set on.
const POLICY_FLAG_DISABLE_KEY_REPEAT: i32 = 0x0800_0000;

/// A key press that has already satisfied the slow keys threshold and whose delayed `DOWN` event
/// has been forwarded to the next filter.
///
/// Later events for the same key (matched on `device_id` + `scancode`) are passed straight
/// through, with their `downTime` rewritten to `down_time`.
#[derive(Debug)]
struct OngoingKeyDown {
    /// Scan code of the held key.
    scancode: i32,
    /// Device the key press came from.
    device_id: i32,
    /// Down time (ns) of the forwarded, delayed `DOWN` event: the original down time plus the
    /// slow keys threshold.
    down_time: i64,
}

/// Mutable state of `SlowKeysFilter`, shared between clones behind an `Arc<RwLock<..>>`.
struct SlowKeysFilterInner {
    /// Next filter in the chain; receives every event this filter lets through.
    next: Box<dyn Filter + Send + Sync>,
    /// How long (in nanoseconds) a key must be held before its press is accepted.
    slow_key_threshold_ns: i64,
    /// IDs of the devices whose keyboard events are subject to slow keys.
    supported_devices: HashSet<i32>,
    /// KeyEvents (ACTION_DOWN) that are held back by the slow keys filter. Each is stored with its
    /// `downTime`/`eventTime` already shifted by the threshold, and is forwarded once that time is
    /// reached, unless the key is released first.
    pending_down_events: Vec<KeyEvent>,
    /// Key streams whose press duration already exceeded the slow keys threshold. Any further
    /// ACTION_DOWN (if repeats are generated on the HW side) or ACTION_UP for these keys is passed
    /// through without waiting.
    ongoing_down_events: Vec<OngoingKeyDown>,
    /// Thread used to schedule a wake-up for the earliest pending key press.
    input_filter_thread: InputFilterThread,
}

/// Input filter implementing the slow keys accessibility feature.
///
/// Cloning is cheap and every clone shares the same state. A clone is registered with the
/// `InputFilterThread` so the thread can deliver timeout callbacks to this filter.
#[derive(Clone)]
pub struct SlowKeysFilter(Arc<RwLock<SlowKeysFilterInner>>);

impl SlowKeysFilter {
    /// Create a new SlowKeysFilter instance.
    pub fn new(
        next: Box<dyn Filter + Send + Sync>,
        slow_key_threshold_ns: i64,
        input_filter_thread: InputFilterThread,
    ) -> SlowKeysFilter {
        let filter = Self(Arc::new(RwLock::new(SlowKeysFilterInner {
            next,
            slow_key_threshold_ns,
            supported_devices: HashSet::new(),
            pending_down_events: Vec::new(),
            ongoing_down_events: Vec::new(),
            input_filter_thread: input_filter_thread.clone(),
        })));
        input_filter_thread.register_thread_callback(Box::new(filter.clone()));
        filter
    }

    fn read_inner(&self) -> RwLockReadGuard<'_, SlowKeysFilterInner> {
        self.0.read().unwrap()
    }

    fn write_inner(&self) -> RwLockWriteGuard<'_, SlowKeysFilterInner> {
        self.0.write().unwrap()
    }

    fn request_next_callback(&self) {
        let slow_filter = &self.read_inner();
        if slow_filter.pending_down_events.is_empty() {
            return;
        }
        if let Some(event) = slow_filter.pending_down_events.iter().min_by_key(|x| x.downTime) {
            slow_filter.input_filter_thread.request_timeout_at_time(event.downTime);
        }
    }
}

impl Filter for SlowKeysFilter {
    /// Holds back key presses from supported keyboards until they have been held for the slow keys
    /// threshold.
    ///
    /// Events from unsupported devices, or whose source is not a keyboard, are forwarded unchanged.
    /// For supported keyboards:
    /// - A key that already satisfied the threshold ("ongoing") forwards all further events
    ///   (hardware repeats and the final `UP`). Their `downTime` is rewritten to the delayed down
    ///   time sent downstream; the `UP` ends the tracking.
    /// - A new `DOWN` is stored as pending, delayed by the threshold, with key repeat disabled.
    ///   A second `DOWN` for a key that is already pending is dropped.
    /// - An `UP` for a pending key is dropped along with the pending `DOWN`, because the key was
    ///   released before the threshold.
    /// - An `UP` for a key this filter never held back is forwarded. This covers a key pressed
    ///   before the device became supported. Dropping it would leave the next filter with a `DOWN`
    ///   that never gets an `UP` (a stuck key).
    fn notify_key(&mut self, event: &KeyEvent) {
        {
            // acquire write lock
            let mut slow_filter = self.write_inner();
            if !(slow_filter.supported_devices.contains(&event.deviceId)
                && event.source.0 & Source::KEYBOARD.0 != 0)
            {
                slow_filter.next.notify_key(event);
                return;
            }
            // Pass all events through if key down has already been processed.
            // Rewrite the down time to the delayed one that was sent downstream, so the forwarded
            // stream stays self-consistent.
            if let Some(index) = slow_filter
                .ongoing_down_events
                .iter()
                .position(|x| x.device_id == event.deviceId && x.scancode == event.scanCode)
            {
                let mut new_event = *event;
                new_event.downTime = slow_filter.ongoing_down_events[index].down_time;
                slow_filter.next.notify_key(&new_event);
                if event.action == KeyEventAction::UP {
                    slow_filter.ongoing_down_events.remove(index);
                }
                return;
            }
            let pending_index = slow_filter
                .pending_down_events
                .iter()
                .position(|x| x.deviceId == event.deviceId && x.scanCode == event.scanCode);
            match event.action {
                KeyEventAction::DOWN => {
                    if pending_index.is_some() {
                        debug!("Dropping key down event since another pending down event exists");
                        return;
                    }
                    let mut pending_event = *event;
                    pending_event.downTime += slow_filter.slow_key_threshold_ns;
                    pending_event.eventTime = pending_event.downTime;
                    // Currently a slow keys user ends up repeating the pressed key quite often
                    // since default repeat thresholds are very low, so blocking repeat for events
                    // when slow keys is enabled.
                    // TODO(b/322327461): Allow key repeat with slow keys, once repeat key rate and
                    //  thresholds can be modified in the settings.
                    pending_event.policyFlags |= POLICY_FLAG_DISABLE_KEY_REPEAT;
                    slow_filter.pending_down_events.push(pending_event);
                }
                KeyEventAction::UP => match pending_index {
                    Some(index) => {
                        debug!("Dropping key up event due to insufficient press duration");
                        slow_filter.pending_down_events.remove(index);
                    }
                    None => {
                        // No DOWN for this key is held back or ongoing. Either downstream already
                        // saw the DOWN (e.g. the device only became supported mid-press), or it
                        // saw nothing. Forward the UP so a previously delivered DOWN cannot get
                        // stuck.
                        slow_filter.next.notify_key(event);
                    }
                },
                _ => (),
            }
        } // release write lock
        self.request_next_callback();
    }

    /// Rebuilds the set of supported devices and forgets state for devices that disappeared.
    ///
    /// Pending and ongoing key presses from disconnected devices are discarded. Supported devices
    /// are alphabetic keyboards, plus external non-alphabetic keyboards. The virtual keyboard and
    /// devices without a keyboard are excluded. The device list is then passed down the chain.
    fn notify_devices_changed(&mut self, device_infos: &[DeviceInfo]) {
        let mut slow_filter = self.write_inner();
        slow_filter
            .pending_down_events
            .retain(|event| device_infos.iter().any(|x| event.deviceId == x.deviceId));
        slow_filter
            .ongoing_down_events
            .retain(|event| device_infos.iter().any(|x| event.device_id == x.deviceId));
        slow_filter.supported_devices.clear();
        for device_info in device_infos {
            if device_info.deviceId == VIRTUAL_KEYBOARD_DEVICE_ID {
                continue;
            }
            if device_info.keyboardType == KeyboardType::None as i32 {
                continue;
            }
            // Support Alphabetic keyboards and Non-alphabetic external keyboards
            if device_info.external || device_info.keyboardType == KeyboardType::Alphabetic as i32 {
                slow_filter.supported_devices.insert(device_info.deviceId);
            }
        }
        slow_filter.next.notify_devices_changed(device_infos);
    }

    /// Stops receiving timeout callbacks and destroys the rest of the chain.
    fn destroy(&mut self) {
        // Do not hold this filter's lock while unregistering. `unregister_thread_callback` may have
        // to wait for a timeout callback that is running on the filter thread right now. That
        // callback (`notify_timeout_expired`) blocks on this filter's write lock, so holding the
        // lock here could deadlock. The temporary read guard is dropped at the end of this
        // statement.
        let input_filter_thread = self.read_inner().input_filter_thread.clone();
        input_filter_thread.unregister_thread_callback(Box::new(self.clone()));
        self.write_inner().next.destroy();
    }

    /// Delegates state saving to the next filter; this filter persists nothing itself.
    fn save(
        &mut self,
        state: HashMap<&'static str, Box<dyn Any + Send + Sync>>,
    ) -> HashMap<&'static str, Box<dyn Any + Send + Sync>> {
        let mut slow_filter = self.write_inner();
        slow_filter.next.save(state)
    }

    /// Delegates state restoration to the next filter.
    fn restore(&mut self, state: &HashMap<&'static str, Box<dyn Any + Send + Sync>>) {
        let mut slow_filter = self.write_inner();
        slow_filter.next.restore(state);
    }

    /// Appends this filter's state to `dump_str`, then lets the next filter append its own.
    fn dump(&mut self, dump_str: String) -> String {
        let mut slow_filter = self.write_inner();
        let mut result = "Slow Keys filter: \n".to_string();
        result += &format!("\tthreshold = {:?}ns\n", slow_filter.slow_key_threshold_ns);
        result += &format!("\tongoing_down_events = {:?}\n", slow_filter.ongoing_down_events);
        result += &format!("\tpending_down_events = {:?}\n", slow_filter.pending_down_events);
        result += &format!("\tsupported_devices = {:?}\n", slow_filter.supported_devices);
        slow_filter.next.dump(dump_str + &result)
    }
}

impl ThreadCallback for SlowKeysFilter {
    /// Releases every held-back key press whose delayed down time is at or before `when_nanos`.
    ///
    /// Each released `DOWN` is forwarded to the next filter in its original order. It is also
    /// recorded as an ongoing key press, so the rest of its event stream passes straight through.
    /// A new timeout is then requested for the earliest press that is still pending, if any.
    fn notify_timeout_expired(&self, when_nanos: i64) {
        {
            // The write lock is held only for this scope.
            let mut state = self.write_inner();
            let (due, not_yet_due): (Vec<_>, Vec<_>) = state
                .pending_down_events
                .drain(..)
                .partition(|pending| pending.downTime <= when_nanos);
            state.pending_down_events = not_yet_due;
            for released in &due {
                state.next.notify_key(released);
                state.ongoing_down_events.push(OngoingKeyDown {
                    scancode: released.scanCode,
                    device_id: released.deviceId,
                    down_time: released.downTime,
                });
            }
        }
        self.request_next_callback();
    }

    /// Identifier used by the input filter thread for this callback.
    fn name(&self) -> &str {
        "slow_keys_filter"
    }
}

// Unit tests for SlowKeysFilter. Events are fed into the filter, and the tests check what reaches
// the next filter (a TestFilter) via `last_event()`. The timing tests use a real InputFilterThread
// and real sleeps, so they depend on the filter thread waking up within 2x the threshold.
#[cfg(test)]
mod tests {
    use crate::input_filter::{
        test_callbacks::TestCallbacks, test_filter::TestFilter, Filter, InputFilterThreadCreator,
    };
    use crate::input_filter_thread::InputFilterThread;
    use crate::slow_keys_filter::{SlowKeysFilter, POLICY_FLAG_DISABLE_KEY_REPEAT};
    use android_hardware_input_common::aidl::android::hardware::input::common::Source::Source;
    use binder::Strong;
    use com_android_server_inputflinger::aidl::com::android::server::inputflinger::{
        DeviceInfo::DeviceInfo, KeyEvent::KeyEvent, KeyEventAction::KeyEventAction,
    };
    use input::KeyboardType;
    use nix::{sys::time::TimeValLike, time::clock_gettime, time::ClockId};
    use std::sync::{Arc, RwLock};
    use std::time::Duration;

    // Template keyboard DOWN event from device 1; tests override fields with struct-update syntax.
    static BASE_KEY_EVENT: KeyEvent = KeyEvent {
        id: 1,
        deviceId: 1,
        downTime: 0,
        readTime: 0,
        eventTime: 0,
        source: Source::KEYBOARD,
        displayId: 0,
        policyFlags: 0,
        action: KeyEventAction::DOWN,
        flags: 0,
        keyCode: 1,
        scanCode: 0,
        metaState: 0,
    };

    static SLOW_KEYS_THRESHOLD_NS: i64 = 100 * 1000000; // 100 ms

    // Internal non-alphabetic keyboards are not supported, so their events are forwarded as-is.
    #[test]
    fn test_is_notify_key_for_internal_non_alphabetic_keyboard_not_blocked() {
        let test_callbacks = TestCallbacks::new();
        let test_thread = get_thread(test_callbacks.clone());
        let next = TestFilter::new();
        let mut filter = setup_filter_with_internal_device(
            Box::new(next.clone()),
            test_thread.clone(),
            1, /* device_id */
            SLOW_KEYS_THRESHOLD_NS,
            KeyboardType::NonAlphabetic,
        );

        let event = KeyEvent { action: KeyEventAction::DOWN, ..BASE_KEY_EVENT };
        filter.notify_key(&event);
        assert_eq!(next.last_event().unwrap(), event);
    }

    // Events whose source lacks the KEYBOARD bit (here STYLUS) pass through even from a supported
    // device.
    #[test]
    fn test_is_notify_key_for_external_stylus_not_blocked() {
        let test_callbacks = TestCallbacks::new();
        let test_thread = get_thread(test_callbacks.clone());
        let next = TestFilter::new();
        let mut filter = setup_filter_with_external_device(
            Box::new(next.clone()),
            test_thread.clone(),
            1, /* device_id */
            SLOW_KEYS_THRESHOLD_NS,
            KeyboardType::NonAlphabetic,
        );

        let event =
            KeyEvent { action: KeyEventAction::DOWN, source: Source::STYLUS, ..BASE_KEY_EVENT };
        filter.notify_key(&event);
        assert_eq!(next.last_event().unwrap(), event);
    }

    // An external non-alphabetic device with a KEYBOARD|DPAD source is held back. The delayed DOWN
    // has its times shifted by the threshold and key repeat disabled. The UP is forwarded with the
    // delayed down time.
    #[test]
    fn test_notify_key_for_tv_remote_when_key_pressed_for_threshold_time() {
        let test_callbacks = TestCallbacks::new();
        let test_thread = get_thread(test_callbacks.clone());
        let next = TestFilter::new();
        let mut filter = setup_filter_with_external_device(
            Box::new(next.clone()),
            test_thread.clone(),
            1, /* device_id */
            SLOW_KEYS_THRESHOLD_NS,
            KeyboardType::NonAlphabetic,
        );
        let down_time = clock_gettime(ClockId::CLOCK_MONOTONIC).unwrap().num_nanoseconds();
        let source = Source(Source::KEYBOARD.0 | Source::DPAD.0);
        filter.notify_key(&KeyEvent {
            action: KeyEventAction::DOWN,
            downTime: down_time,
            eventTime: down_time,
            source,
            ..BASE_KEY_EVENT
        });
        assert!(next.last_event().is_none());

        std::thread::sleep(Duration::from_nanos(2 * SLOW_KEYS_THRESHOLD_NS as u64));
        assert_eq!(
            next.last_event().unwrap(),
            KeyEvent {
                action: KeyEventAction::DOWN,
                downTime: down_time + SLOW_KEYS_THRESHOLD_NS,
                eventTime: down_time + SLOW_KEYS_THRESHOLD_NS,
                source,
                policyFlags: POLICY_FLAG_DISABLE_KEY_REPEAT,
                ..BASE_KEY_EVENT
            }
        );

        let up_time = clock_gettime(ClockId::CLOCK_MONOTONIC).unwrap().num_nanoseconds();
        filter.notify_key(&KeyEvent {
            action: KeyEventAction::UP,
            downTime: down_time,
            eventTime: up_time,
            source,
            ..BASE_KEY_EVENT
        });

        assert_eq!(
            next.last_event().unwrap(),
            KeyEvent {
                action: KeyEventAction::UP,
                downTime: down_time + SLOW_KEYS_THRESHOLD_NS,
                eventTime: up_time,
                source,
                ..BASE_KEY_EVENT
            }
        );
    }

    // Internal alphabetic keyboards are supported: DOWN is delayed, then UP passes through with
    // the rewritten down time.
    #[test]
    fn test_notify_key_for_internal_alphabetic_keyboard_when_key_pressed_for_threshold_time() {
        let test_callbacks = TestCallbacks::new();
        let test_thread = get_thread(test_callbacks.clone());
        let next = TestFilter::new();
        let mut filter = setup_filter_with_internal_device(
            Box::new(next.clone()),
            test_thread.clone(),
            1, /* device_id */
            SLOW_KEYS_THRESHOLD_NS,
            KeyboardType::Alphabetic,
        );
        let down_time = clock_gettime(ClockId::CLOCK_MONOTONIC).unwrap().num_nanoseconds();
        filter.notify_key(&KeyEvent {
            action: KeyEventAction::DOWN,
            downTime: down_time,
            eventTime: down_time,
            ..BASE_KEY_EVENT
        });
        assert!(next.last_event().is_none());

        std::thread::sleep(Duration::from_nanos(2 * SLOW_KEYS_THRESHOLD_NS as u64));
        assert_eq!(
            next.last_event().unwrap(),
            KeyEvent {
                action: KeyEventAction::DOWN,
                downTime: down_time + SLOW_KEYS_THRESHOLD_NS,
                eventTime: down_time + SLOW_KEYS_THRESHOLD_NS,
                policyFlags: POLICY_FLAG_DISABLE_KEY_REPEAT,
                ..BASE_KEY_EVENT
            }
        );

        let up_time = clock_gettime(ClockId::CLOCK_MONOTONIC).unwrap().num_nanoseconds();
        filter.notify_key(&KeyEvent {
            action: KeyEventAction::UP,
            downTime: down_time,
            eventTime: up_time,
            ..BASE_KEY_EVENT
        });

        assert_eq!(
            next.last_event().unwrap(),
            KeyEvent {
                action: KeyEventAction::UP,
                downTime: down_time + SLOW_KEYS_THRESHOLD_NS,
                eventTime: up_time,
                ..BASE_KEY_EVENT
            }
        );
    }

    // Same flow as above for an external alphabetic keyboard.
    #[test]
    fn test_notify_key_for_external_keyboard_when_key_pressed_for_threshold_time() {
        let test_callbacks = TestCallbacks::new();
        let test_thread = get_thread(test_callbacks.clone());
        let next = TestFilter::new();
        let mut filter = setup_filter_with_external_device(
            Box::new(next.clone()),
            test_thread.clone(),
            1, /* device_id */
            SLOW_KEYS_THRESHOLD_NS,
            KeyboardType::Alphabetic,
        );
        let down_time = clock_gettime(ClockId::CLOCK_MONOTONIC).unwrap().num_nanoseconds();
        filter.notify_key(&KeyEvent {
            action: KeyEventAction::DOWN,
            downTime: down_time,
            eventTime: down_time,
            ..BASE_KEY_EVENT
        });
        assert!(next.last_event().is_none());

        std::thread::sleep(Duration::from_nanos(2 * SLOW_KEYS_THRESHOLD_NS as u64));
        assert_eq!(
            next.last_event().unwrap(),
            KeyEvent {
                action: KeyEventAction::DOWN,
                downTime: down_time + SLOW_KEYS_THRESHOLD_NS,
                eventTime: down_time + SLOW_KEYS_THRESHOLD_NS,
                policyFlags: POLICY_FLAG_DISABLE_KEY_REPEAT,
                ..BASE_KEY_EVENT
            }
        );

        let up_time = clock_gettime(ClockId::CLOCK_MONOTONIC).unwrap().num_nanoseconds();
        filter.notify_key(&KeyEvent {
            action: KeyEventAction::UP,
            downTime: down_time,
            eventTime: up_time,
            ..BASE_KEY_EVENT
        });

        assert_eq!(
            next.last_event().unwrap(),
            KeyEvent {
                action: KeyEventAction::UP,
                downTime: down_time + SLOW_KEYS_THRESHOLD_NS,
                eventTime: up_time,
                ..BASE_KEY_EVENT
            }
        );
    }

    // Releasing a key before the threshold drops both the pending DOWN and the UP.
    #[test]
    fn test_notify_key_for_external_keyboard_when_key_not_pressed_for_threshold_time() {
        let test_callbacks = TestCallbacks::new();
        let test_thread = get_thread(test_callbacks.clone());
        let next = TestFilter::new();
        let mut filter = setup_filter_with_external_device(
            Box::new(next.clone()),
            test_thread.clone(),
            1, /* device_id */
            SLOW_KEYS_THRESHOLD_NS,
            KeyboardType::Alphabetic,
        );
        let mut now = clock_gettime(ClockId::CLOCK_MONOTONIC).unwrap().num_nanoseconds();
        filter.notify_key(&KeyEvent {
            action: KeyEventAction::DOWN,
            downTime: now,
            eventTime: now,
            ..BASE_KEY_EVENT
        });

        std::thread::sleep(Duration::from_nanos(SLOW_KEYS_THRESHOLD_NS as u64 / 2));

        now = clock_gettime(ClockId::CLOCK_MONOTONIC).unwrap().num_nanoseconds();
        filter.notify_key(&KeyEvent {
            action: KeyEventAction::UP,
            downTime: now,
            eventTime: now,
            ..BASE_KEY_EVENT
        });

        assert!(next.last_event().is_none());
    }

    // Removing the device discards its pending DOWN, so nothing is forwarded after the threshold.
    #[test]
    fn test_notify_key_for_external_keyboard_when_device_removed_before_threshold_time() {
        let test_callbacks = TestCallbacks::new();
        let test_thread = get_thread(test_callbacks.clone());
        let next = TestFilter::new();
        let mut filter = setup_filter_with_external_device(
            Box::new(next.clone()),
            test_thread.clone(),
            1, /* device_id */
            SLOW_KEYS_THRESHOLD_NS,
            KeyboardType::Alphabetic,
        );

        let now = clock_gettime(ClockId::CLOCK_MONOTONIC).unwrap().num_nanoseconds();
        filter.notify_key(&KeyEvent {
            action: KeyEventAction::DOWN,
            downTime: now,
            eventTime: now,
            ..BASE_KEY_EVENT
        });

        filter.notify_devices_changed(&[]);
        std::thread::sleep(Duration::from_nanos(2 * SLOW_KEYS_THRESHOLD_NS as u64));

        assert!(next.last_event().is_none());
    }

    // Builds a filter that knows about one external device of the given keyboard type.
    fn setup_filter_with_external_device(
        next: Box<dyn Filter + Send + Sync>,
        test_thread: InputFilterThread,
        device_id: i32,
        threshold: i64,
        keyboard_type: KeyboardType,
    ) -> SlowKeysFilter {
        setup_filter_with_devices(
            next,
            test_thread,
            &[DeviceInfo {
                deviceId: device_id,
                external: true,
                keyboardType: keyboard_type as i32,
            }],
            threshold,
        )
    }

    // Builds a filter that knows about one internal device of the given keyboard type.
    fn setup_filter_with_internal_device(
        next: Box<dyn Filter + Send + Sync>,
        test_thread: InputFilterThread,
        device_id: i32,
        threshold: i64,
        keyboard_type: KeyboardType,
    ) -> SlowKeysFilter {
        setup_filter_with_devices(
            next,
            test_thread,
            &[DeviceInfo {
                deviceId: device_id,
                external: false,
                keyboardType: keyboard_type as i32,
            }],
            threshold,
        )
    }

    // Creates the filter and reports the given device list to it.
    fn setup_filter_with_devices(
        next: Box<dyn Filter + Send + Sync>,
        test_thread: InputFilterThread,
        devices: &[DeviceInfo],
        threshold: i64,
    ) -> SlowKeysFilter {
        let mut filter = SlowKeysFilter::new(next, threshold, test_thread);
        filter.notify_devices_changed(devices);
        filter
    }

    // Creates a real InputFilterThread backed by the test callbacks.
    fn get_thread(callbacks: TestCallbacks) -> InputFilterThread {
        InputFilterThread::new(InputFilterThreadCreator::new(Arc::new(RwLock::new(Strong::new(
            Box::new(callbacks),
        )))))
    }
}
