Merge "aconfig: Rust codegen 2nd iteration" into main
diff --git a/tools/aconfig/src/codegen_rust.rs b/tools/aconfig/src/codegen_rust.rs
index f931418..45303ce 100644
--- a/tools/aconfig/src/codegen_rust.rs
+++ b/tools/aconfig/src/codegen_rust.rs
@@ -19,10 +19,14 @@
 use tinytemplate::TinyTemplate;
 
 use crate::codegen;
-use crate::commands::OutputFile;
+use crate::commands::{CodegenMode, OutputFile};
 use crate::protos::{ProtoFlagPermission, ProtoFlagState, ProtoParsedFlag};
 
-pub fn generate_rust_code<'a, I>(package: &str, parsed_flags_iter: I) -> Result<OutputFile>
+pub fn generate_rust_code<'a, I>(
+    package: &str,
+    parsed_flags_iter: I,
+    codegen_mode: CodegenMode,
+) -> Result<OutputFile>
 where
     I: Iterator<Item = &'a ProtoParsedFlag>,
 {
@@ -34,7 +38,13 @@
         modules: package.split('.').map(|s| s.to_string()).collect::<Vec<_>>(),
     };
     let mut template = TinyTemplate::new();
-    template.add_template("rust_code_gen", include_str!("../templates/rust.template"))?;
+    template.add_template(
+        "rust_code_gen",
+        match codegen_mode {
+            CodegenMode::Production => include_str!("../templates/rust_prod.template"),
+            CodegenMode::Test => include_str!("../templates/rust_test.template"),
+        },
+    )?;
     let contents = template.render("rust_code_gen", &context)?;
     let path = ["src", "lib.rs"].iter().collect();
     Ok(OutputFile { contents: contents.into(), path })
@@ -49,41 +59,27 @@
 
 #[derive(Serialize)]
 struct TemplateParsedFlag {
+    // Template inputs: `readwrite` picks the branch; `default_value` is "true" or "false".
+    pub readwrite: bool,
+    pub default_value: String,
     pub name: String,
     pub device_config_namespace: String,
     pub device_config_flag: String,
-
-    // TinyTemplate's conditionals are limited to single <bool> expressions; list all options here
-    // Invariant: exactly one of these fields will be true
-    pub is_read_only_enabled: bool,
-    pub is_read_only_disabled: bool,
-    pub is_read_write: bool,
 }
 
 impl TemplateParsedFlag {
-    #[allow(clippy::nonminimal_bool)]
+    /// Builds the template context for one parsed flag of `package`.
     fn new(package: &str, pf: &ProtoParsedFlag) -> Self {
-        let template = TemplateParsedFlag {
+        TemplateParsedFlag {
+            readwrite: pf.permission() == ProtoFlagPermission::READ_WRITE,
+            default_value: match pf.state() {
+                ProtoFlagState::ENABLED => "true".to_string(),
+                ProtoFlagState::DISABLED => "false".to_string(),
+            },
             name: pf.name().to_string(),
             device_config_namespace: pf.namespace().to_string(),
             device_config_flag: codegen::create_device_config_ident(package, pf.name())
                 .expect("values checked at flag parse time"),
-            is_read_only_enabled: pf.permission() == ProtoFlagPermission::READ_ONLY
-                && pf.state() == ProtoFlagState::ENABLED,
-            is_read_only_disabled: pf.permission() == ProtoFlagPermission::READ_ONLY
-                && pf.state() == ProtoFlagState::DISABLED,
-            is_read_write: pf.permission() == ProtoFlagPermission::READ_WRITE,
-        };
-        #[rustfmt::skip]
-        debug_assert!(
-            (template.is_read_only_enabled && !template.is_read_only_disabled && !template.is_read_write) ||
-            (!template.is_read_only_enabled && template.is_read_only_disabled && !template.is_read_write) ||
-            (!template.is_read_only_enabled && !template.is_read_only_disabled && template.is_read_write),
-            "TemplateParsedFlag invariant failed: {} {} {}",
-            template.is_read_only_enabled,
-            template.is_read_only_disabled,
-            template.is_read_write,
-        );
-        template
+        }
     }
 }
@@ -92,48 +88,220 @@
 mod tests {
     use super::*;
 
-    #[test]
-    fn test_generate_rust_code() {
-        let parsed_flags = crate::test::parse_test_flags();
-        let generated =
-            generate_rust_code(crate::test::TEST_PACKAGE, parsed_flags.parsed_flag.iter()).unwrap();
-        assert_eq!("src/lib.rs", format!("{}", generated.path.display()));
-        let expected = r#"
-pub mod com {
-pub mod android {
-pub mod aconfig {
-pub mod test {
+    const PROD_EXPECTED: &str = r#"
+//! codegenerated rust flag lib
+
+/// flag provider
+pub struct FlagProvider;
+
+impl FlagProvider {
+    /// query flag disabled_ro
+    pub fn disabled_ro(&self) -> bool {
+        false
+    }
+
+    /// query flag disabled_rw
+    pub fn disabled_rw(&self) -> bool {
+        flags_rust::GetServerConfigurableFlag(
+            "aconfig_test",
+            "com.android.aconfig.test.disabled_rw",
+            "false") == "true"
+    }
+
+    /// query flag enabled_ro
+    pub fn enabled_ro(&self) -> bool {
+        true
+    }
+
+    /// query flag enabled_rw
+    pub fn enabled_rw(&self) -> bool {
+        flags_rust::GetServerConfigurableFlag(
+            "aconfig_test",
+            "com.android.aconfig.test.enabled_rw",
+            "true") == "true"
+    }
+}
+
+/// flag provider
+pub static PROVIDER: FlagProvider = FlagProvider;
+
+/// query flag disabled_ro
 #[inline(always)]
-pub const fn r#disabled_ro() -> bool {
+pub fn disabled_ro() -> bool {
     false
 }
 
+/// query flag disabled_rw
 #[inline(always)]
-pub fn r#disabled_rw() -> bool {
-    flags_rust::GetServerConfigurableFlag("aconfig_test", "com.android.aconfig.test.disabled_rw", "false") == "true"
+pub fn disabled_rw() -> bool {
+    PROVIDER.disabled_rw()
 }
 
+/// query flag enabled_ro
 #[inline(always)]
-pub const fn r#enabled_ro() -> bool {
+pub fn enabled_ro() -> bool {
     true
 }
 
+/// query flag enabled_rw
 #[inline(always)]
-pub fn r#enabled_rw() -> bool {
-    flags_rust::GetServerConfigurableFlag("aconfig_test", "com.android.aconfig.test.enabled_rw", "false") == "true"
-}
-
-}
-}
-}
+pub fn enabled_rw() -> bool {
+    PROVIDER.enabled_rw()
 }
 "#;
+
+    const TEST_EXPECTED: &str = r#"
+//! codegenerated rust flag lib
+
+use std::collections::BTreeMap;
+use std::sync::Mutex;
+
+/// flag provider
+pub struct FlagProvider {
+    overrides: BTreeMap<&'static str, bool>,
+}
+
+impl FlagProvider {
+    /// query flag disabled_ro
+    pub fn disabled_ro(&self) -> bool {
+        self.overrides.get("disabled_ro").copied()
+            .unwrap_or(false)
+    }
+
+    /// set flag disabled_ro
+    pub fn set_disabled_ro(&mut self, val: bool) {
+        self.overrides.insert("disabled_ro", val);
+    }
+
+    /// query flag disabled_rw
+    pub fn disabled_rw(&self) -> bool {
+        self.overrides.get("disabled_rw").copied()
+            .unwrap_or_else(|| flags_rust::GetServerConfigurableFlag(
+                "aconfig_test",
+                "com.android.aconfig.test.disabled_rw",
+                "false") == "true")
+    }
+
+    /// set flag disabled_rw
+    pub fn set_disabled_rw(&mut self, val: bool) {
+        self.overrides.insert("disabled_rw", val);
+    }
+
+    /// query flag enabled_ro
+    pub fn enabled_ro(&self) -> bool {
+        self.overrides.get("enabled_ro").copied()
+            .unwrap_or(true)
+    }
+
+    /// set flag enabled_ro
+    pub fn set_enabled_ro(&mut self, val: bool) {
+        self.overrides.insert("enabled_ro", val);
+    }
+
+    /// query flag enabled_rw
+    pub fn enabled_rw(&self) -> bool {
+        self.overrides.get("enabled_rw").copied()
+            .unwrap_or_else(|| flags_rust::GetServerConfigurableFlag(
+                "aconfig_test",
+                "com.android.aconfig.test.enabled_rw",
+                "true") == "true")
+    }
+
+    /// set flag enabled_rw
+    pub fn set_enabled_rw(&mut self, val: bool) {
+        self.overrides.insert("enabled_rw", val);
+    }
+
+    /// clear all flag overrides
+    pub fn reset_flags(&mut self) {
+        self.overrides.clear();
+    }
+}
+
+/// flag provider
+pub static PROVIDER: Mutex<FlagProvider> = Mutex::new(
+    FlagProvider {overrides: BTreeMap::new()}
+);
+
+/// query flag disabled_ro
+#[inline(always)]
+pub fn disabled_ro() -> bool {
+    PROVIDER.lock().unwrap().disabled_ro()
+}
+
+/// set flag disabled_ro
+#[inline(always)]
+pub fn set_disabled_ro(val: bool) {
+    PROVIDER.lock().unwrap().set_disabled_ro(val);
+}
+
+/// query flag disabled_rw
+#[inline(always)]
+pub fn disabled_rw() -> bool {
+    PROVIDER.lock().unwrap().disabled_rw()
+}
+
+/// set flag disabled_rw
+#[inline(always)]
+pub fn set_disabled_rw(val: bool) {
+    PROVIDER.lock().unwrap().set_disabled_rw(val);
+}
+
+/// query flag enabled_ro
+#[inline(always)]
+pub fn enabled_ro() -> bool {
+    PROVIDER.lock().unwrap().enabled_ro()
+}
+
+/// set flag enabled_ro
+#[inline(always)]
+pub fn set_enabled_ro(val: bool) {
+    PROVIDER.lock().unwrap().set_enabled_ro(val);
+}
+
+/// query flag enabled_rw
+#[inline(always)]
+pub fn enabled_rw() -> bool {
+    PROVIDER.lock().unwrap().enabled_rw()
+}
+
+/// set flag enabled_rw
+#[inline(always)]
+pub fn set_enabled_rw(val: bool) {
+    PROVIDER.lock().unwrap().set_enabled_rw(val);
+}
+
+/// clear all flag overrides
+pub fn reset_flags() {
+    PROVIDER.lock().unwrap().reset_flags()
+}
+"#;
+
+    fn test_generate_rust_code(mode: CodegenMode) {
+        let parsed_flags = crate::test::parse_test_flags();
+        let generated =
+            generate_rust_code(crate::test::TEST_PACKAGE, parsed_flags.parsed_flag.iter(), mode)
+                .unwrap();
+        assert_eq!("src/lib.rs", format!("{}", generated.path.display()));
         assert_eq!(
             None,
             crate::test::first_significant_code_diff(
-                expected,
+                match mode {
+                    CodegenMode::Production => PROD_EXPECTED,
+                    CodegenMode::Test => TEST_EXPECTED,
+                },
                 &String::from_utf8(generated.contents).unwrap()
             )
         );
     }
+
+    #[test]
+    fn test_generate_rust_code_for_prod() {
+        test_generate_rust_code(CodegenMode::Production);
+    }
+
+    #[test]
+    fn test_generate_rust_code_for_test() {
+        test_generate_rust_code(CodegenMode::Test);
+    }
 }
diff --git a/tools/aconfig/src/commands.rs b/tools/aconfig/src/commands.rs
index 687f319..0ac84b2 100644
--- a/tools/aconfig/src/commands.rs
+++ b/tools/aconfig/src/commands.rs
@@ -108,7 +108,11 @@
             crate::protos::flag_value::verify_fields(&flag_value)
                 .with_context(|| format!("Failed to parse {}", input.source))?;
 
-            let Some(parsed_flag) = parsed_flags.parsed_flag.iter_mut().find(|pf| pf.package() == flag_value.package() && pf.name() == flag_value.name()) else {
+            let Some(parsed_flag) = parsed_flags
+                .parsed_flag
+                .iter_mut()
+                .find(|pf| pf.package() == flag_value.package() && pf.name() == flag_value.name())
+            else {
                 // (silently) skip unknown flags
                 continue;
             };
@@ -151,12 +155,12 @@
     generate_cpp_code(package, parsed_flags.parsed_flag.iter(), codegen_mode)
 }
 
-pub fn create_rust_lib(mut input: Input) -> Result<OutputFile> {
+pub fn create_rust_lib(mut input: Input, codegen_mode: CodegenMode) -> Result<OutputFile> {
     let parsed_flags = input.try_parse_flags()?;
     let Some(package) = find_unique_package(&parsed_flags) else {
         bail!("no parsed flags, or the parsed flags use different packages");
     };
-    generate_rust_code(package, parsed_flags.parsed_flag.iter())
+    generate_rust_code(package, parsed_flags.parsed_flag.iter(), codegen_mode)
 }
 
 pub fn create_device_config_defaults(mut input: Input) -> Result<Vec<u8>> {
diff --git a/tools/aconfig/src/main.rs b/tools/aconfig/src/main.rs
index 72feb94..151cbe8 100644
--- a/tools/aconfig/src/main.rs
+++ b/tools/aconfig/src/main.rs
@@ -71,7 +71,13 @@
         .subcommand(
             Command::new("create-rust-lib")
                 .arg(Arg::new("cache").long("cache").required(true))
-                .arg(Arg::new("out").long("out").required(true)),
+                .arg(Arg::new("out").long("out").required(true))
+                .arg(
+                    Arg::new("mode")
+                        .long("mode")
+                        .value_parser(EnumValueParser::<commands::CodegenMode>::new())
+                        .default_value("production"),
+                ),
         )
         .subcommand(
             Command::new("create-device-config-defaults")
@@ -178,7 +184,8 @@
         }
         Some(("create-rust-lib", sub_matches)) => {
             let cache = open_single_file(sub_matches, "cache")?;
-            let generated_file = commands::create_rust_lib(cache)?;
+            let mode = get_required_arg::<CodegenMode>(sub_matches, "mode")?;
+            let generated_file = commands::create_rust_lib(cache, *mode)?;
             let dir = PathBuf::from(get_required_arg::<String>(sub_matches, "out")?);
             write_output_file_realtive_to_dir(&dir, &generated_file)?;
         }
diff --git a/tools/aconfig/templates/rust.template b/tools/aconfig/templates/rust.template
deleted file mode 100644
index 960c494..0000000
--- a/tools/aconfig/templates/rust.template
+++ /dev/null
@@ -1,29 +0,0 @@
-{{- for mod in modules -}}
-pub mod {mod} \{
-{{ endfor -}}
-{{- for flag in template_flags -}}
-{{- if flag.is_read_only_disabled -}}
-#[inline(always)]
-pub const fn r#{flag.name}() -> bool \{
-    false
-}
-
-{{ endif -}}
-{{- if flag.is_read_only_enabled -}}
-#[inline(always)]
-pub const fn r#{flag.name}() -> bool \{
-    true
-}
-
-{{ endif -}}
-{{- if flag.is_read_write -}}
-#[inline(always)]
-pub fn r#{flag.name}() -> bool \{
-    flags_rust::GetServerConfigurableFlag("{flag.device_config_namespace}", "{flag.device_config_flag}", "false") == "true"
-}
-
-{{ endif -}}
-{{- endfor -}}
-{{- for mod in modules -}}
-}
-{{ endfor -}}
diff --git a/tools/aconfig/templates/rust_prod.template b/tools/aconfig/templates/rust_prod.template
new file mode 100644
index 0000000..543107e
--- /dev/null
+++ b/tools/aconfig/templates/rust_prod.template
@@ -0,0 +1,38 @@
+//! codegenerated rust flag lib
+
+/// flag provider
+pub struct FlagProvider;
+
+impl FlagProvider \{
+
+    {{ for flag in template_flags }}
+    /// query flag {flag.name}
+    pub fn {flag.name}(&self) -> bool \{
+    {{ if flag.readwrite -}}
+        flags_rust::GetServerConfigurableFlag(
+          "{flag.device_config_namespace}",
+          "{flag.device_config_flag}",
+          "{flag.default_value}") == "true"
+    {{ -else- }}
+        {flag.default_value}
+    {{ -endif }}
+    }
+    {{ endfor }}
+
+}
+
+/// flag provider
+pub static PROVIDER: FlagProvider = FlagProvider;
+
+{{ for flag in template_flags }}
+/// query flag {flag.name}
+#[inline(always)]
+{{ if flag.readwrite -}}
+pub fn {flag.name}() -> bool \{
+    PROVIDER.{flag.name}()
+{{ -else- }}
+pub fn {flag.name}() -> bool \{
+    {flag.default_value}
+{{ -endif }}
+}
+{{ endfor }}
diff --git a/tools/aconfig/templates/rust_test.template b/tools/aconfig/templates/rust_test.template
new file mode 100644
index 0000000..1e2c28a
--- /dev/null
+++ b/tools/aconfig/templates/rust_test.template
@@ -0,0 +1,60 @@
+//! codegenerated rust flag lib
+
+use std::collections::BTreeMap;
+use std::sync::Mutex;
+
+/// flag provider
+pub struct FlagProvider \{
+    overrides: BTreeMap<&'static str, bool>,
+}
+
+impl FlagProvider \{
+    {{ for flag in template_flags }}
+    /// query flag {flag.name}
+    pub fn {flag.name}(&self) -> bool \{
+        self.overrides.get("{flag.name}").copied()
+        {{ if flag.readwrite -}}
+          .unwrap_or_else(|| flags_rust::GetServerConfigurableFlag(
+            "{flag.device_config_namespace}",
+            "{flag.device_config_flag}",
+            "{flag.default_value}") == "true")
+        {{ -else- }}
+          .unwrap_or({flag.default_value})
+        {{ -endif }}
+    }
+
+    /// set flag {flag.name}
+    pub fn set_{flag.name}(&mut self, val: bool) \{
+        self.overrides.insert("{flag.name}", val);
+    }
+    {{ endfor }}
+
+    /// clear all flag overrides
+    pub fn reset_flags(&mut self) \{
+        self.overrides.clear();
+    }
+}
+
+/// flag provider
+pub static PROVIDER: Mutex<FlagProvider> = Mutex::new(
+    FlagProvider \{overrides: BTreeMap::new()}
+);
+
+{{ for flag in template_flags }}
+/// query flag {flag.name}
+#[inline(always)]
+pub fn {flag.name}() -> bool \{
+    PROVIDER.lock().unwrap().{flag.name}()
+}
+
+/// set flag {flag.name}
+#[inline(always)]
+pub fn set_{flag.name}(val: bool) \{
+    PROVIDER.lock().unwrap().set_{flag.name}(val);
+}
+{{ endfor }}
+
+/// clear all flag overrides
+pub fn reset_flags() \{
+    PROVIDER.lock().unwrap().reset_flags()
+}