check-flagged-apis: parse classes

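Teach check-flagged-apis to parse classes, including inner classes, from
both the API signature file (@FlaggedApi-annotated classes) and the API
versions XML (<class> elements).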

Bug: 334870672
Test: atest --host check-flagged-apis-test
Change-Id: I17f65d3af55a20a1920b47f4c47fd0e92f9fa852
diff --git a/tools/check-flagged-apis/src/com/android/checkflaggedapis/CheckFlaggedApisTest.kt b/tools/check-flagged-apis/src/com/android/checkflaggedapis/CheckFlaggedApisTest.kt
index 5d87a4c..62c9cbb 100644
--- a/tools/check-flagged-apis/src/com/android/checkflaggedapis/CheckFlaggedApisTest.kt
+++ b/tools/check-flagged-apis/src/com/android/checkflaggedapis/CheckFlaggedApisTest.kt
@@ -16,6 +16,8 @@
 package com.android.checkflaggedapis
 
 import android.aconfig.Aconfig
+import android.aconfig.Aconfig.flag_state.DISABLED
+import android.aconfig.Aconfig.flag_state.ENABLED
 import java.io.ByteArrayInputStream
 import java.io.ByteArrayOutputStream
 import java.io.InputStream
@@ -28,10 +30,12 @@
     """
       // Signature format: 2.0
       package android {
-        public final class Clazz {
+        @FlaggedApi("android.flag.foo") public final class Clazz {
           ctor public Clazz();
           field @FlaggedApi("android.flag.foo") public static final int FOO = 1; // 0x1
         }
+        @FlaggedApi("android.flag.bar") public static class Clazz.Builder {
+        }
       }
 """
         .trim()
@@ -44,12 +48,17 @@
           <method name="&lt;init>()V"/>
           <field name="FOO"/>
         </class>
+        <class name="android/Clazz${"$"}Builder" since="2">
+        </class>
       </api>
 """
         .trim()
 
-private fun generateFlagsProto(fooState: Aconfig.flag_state): InputStream {
-  val parsed_flag =
+private fun generateFlagsProto(
+    fooState: Aconfig.flag_state,
+    barState: Aconfig.flag_state
+): InputStream {
+  val fooFlag =
       Aconfig.parsed_flag
           .newBuilder()
           .setPackage("android.flag")
@@ -57,9 +66,18 @@
           .setState(fooState)
           .setPermission(Aconfig.flag_permission.READ_ONLY)
           .build()
-  val parsed_flags = Aconfig.parsed_flags.newBuilder().addParsedFlag(parsed_flag).build()
+  val barFlag =
+      Aconfig.parsed_flag
+          .newBuilder()
+          .setPackage("android.flag")
+          .setName("bar")
+          .setState(barState)
+          .setPermission(Aconfig.flag_permission.READ_ONLY)
+          .build()
+  val flags =
+      Aconfig.parsed_flags.newBuilder().addParsedFlag(fooFlag).addParsedFlag(barFlag).build()
   val binaryProto = ByteArrayOutputStream()
-  parsed_flags.writeTo(binaryProto)
+  flags.writeTo(binaryProto)
   return ByteArrayInputStream(binaryProto.toByteArray())
 }
 
@@ -67,21 +85,32 @@
 class CheckFlaggedApisTest {
   @Test
   fun testParseApiSignature() {
-    val expected = setOf(Pair(Symbol("android.Clazz.FOO"), Flag("android.flag.foo")))
+    val expected =
+        setOf(
+            Symbol("android.Clazz") to Flag("android.flag.foo"),
+            Symbol("android.Clazz.FOO") to Flag("android.flag.foo"),
+            Symbol("android.Clazz.Builder") to Flag("android.flag.bar"),
+        )
     val actual = parseApiSignature("in-memory", API_SIGNATURE.byteInputStream())
     assertEquals(expected, actual)
   }
 
   @Test
   fun testParseFlagValues() {
-    val expected: Map<Flag, Boolean> = mapOf(Flag("android.flag.foo") to true)
-    val actual = parseFlagValues(generateFlagsProto(Aconfig.flag_state.ENABLED))
+    // Give the two flags different states so a mix-up between them would be caught.
+    val expected = mapOf(Flag("android.flag.foo") to true, Flag("android.flag.bar") to false)
+    val actual = parseFlagValues(generateFlagsProto(ENABLED, DISABLED))
     assertEquals(expected, actual)
   }
 
   @Test
   fun testParseApiVersions() {
-    val expected: Set<Symbol> = setOf(Symbol("android.Clazz.FOO"))
+    // Every <class> element yields a symbol (inner classes included, e.g. Clazz$Builder),
+    // as does every <field> element, qualified by its enclosing class.
+    val expected: Set<Symbol> =
+        listOf("android.Clazz", "android.Clazz.FOO", "android.Clazz.Builder")
+            .map { name -> Symbol(name) }
+            .toSet()
     val actual = parseApiVersions(API_VERSIONS.byteInputStream())
     assertEquals(expected, actual)
   }
@@ -92,7 +121,7 @@
     val actual =
         findErrors(
             parseApiSignature("in-memory", API_SIGNATURE.byteInputStream()),
-            parseFlagValues(generateFlagsProto(Aconfig.flag_state.ENABLED)),
+            parseFlagValues(generateFlagsProto(ENABLED, ENABLED)),
             parseApiVersions(API_VERSIONS.byteInputStream()))
     assertEquals(expected, actual)
   }
@@ -101,11 +130,15 @@
   fun testFindErrorsDisabledFlaggedApiIsPresent() {
     val expected =
         setOf<ApiError>(
-            DisabledFlaggedApiIsPresentError(Symbol("android.Clazz.FOO"), Flag("android.flag.foo")))
+            DisabledFlaggedApiIsPresentError(Symbol("android.Clazz"), Flag("android.flag.foo")),
+            DisabledFlaggedApiIsPresentError(Symbol("android.Clazz.FOO"), Flag("android.flag.foo")),
+            DisabledFlaggedApiIsPresentError(
+                Symbol("android.Clazz.Builder"), Flag("android.flag.bar")),
+        ) // class, field and inner class: each is flagged and present in API_VERSIONS
     val actual =
         findErrors(
             parseApiSignature("in-memory", API_SIGNATURE.byteInputStream()),
-            parseFlagValues(generateFlagsProto(Aconfig.flag_state.DISABLED)),
+            parseFlagValues(generateFlagsProto(DISABLED, DISABLED)), // both flags disabled
             parseApiVersions(API_VERSIONS.byteInputStream()))
     assertEquals(expected, actual)
   }
diff --git a/tools/check-flagged-apis/src/com/android/checkflaggedapis/Main.kt b/tools/check-flagged-apis/src/com/android/checkflaggedapis/Main.kt
index c4c5b11..918a5d9 100644
--- a/tools/check-flagged-apis/src/com/android/checkflaggedapis/Main.kt
+++ b/tools/check-flagged-apis/src/com/android/checkflaggedapis/Main.kt
@@ -19,7 +19,9 @@
 
 import android.aconfig.Aconfig
 import com.android.tools.metalava.model.BaseItemVisitor
+import com.android.tools.metalava.model.ClassItem
 import com.android.tools.metalava.model.FieldItem
+import com.android.tools.metalava.model.Item
 import com.android.tools.metalava.model.text.ApiFile
 import com.github.ajalt.clikt.core.CliktCommand
 import com.github.ajalt.clikt.core.ProgramResult
@@ -167,22 +169,31 @@
 }
 
 internal fun parseApiSignature(path: String, input: InputStream): Set<Pair<Symbol, Flag>> {
-  // TODO(334870672): add support for classes and metods
+  // TODO(334870672): add support for methods
   val output = mutableSetOf<Pair<Symbol, Flag>>()
   val visitor =
       object : BaseItemVisitor() {
-        override fun visitField(field: FieldItem) {
-          val flag =
-              field.modifiers
-                  .findAnnotation("android.annotation.FlaggedApi")
-                  ?.findAttribute("value")
-                  ?.value
-                  ?.value() as? String
-          if (flag != null) {
-            val symbol = Symbol.create(field.baselineElementId())
-            output.add(Pair(symbol, Flag(flag)))
+        override fun visitClass(cls: ClassItem) {
+          getFlagOrNull(cls)?.let { flag ->
+            val symbol = Symbol.create(cls.baselineElementId())
+            output.add(Pair(symbol, flag))
           }
         }
+
+        override fun visitField(field: FieldItem) {
+          getFlagOrNull(field)?.let { flag ->
+            val symbol = Symbol.create(field.baselineElementId())
+            output.add(Pair(symbol, flag))
+          }
+        }
+
+        // Returns the flag named by item's @FlaggedApi annotation, or null if item has no
+        // such annotation or its "value" attribute is not a String.
+        private fun getFlagOrNull(item: Item): Flag? {
+          val annotation = item.modifiers.findAnnotation("android.annotation.FlaggedApi")
+          val name = annotation?.findAttribute("value")?.value?.value() as? String
+          return name?.let { Flag(it) }
+        }
       }
   val codebase = ApiFile.parseApi(path, input)
   codebase.accept(visitor)
@@ -203,6 +214,18 @@
   val factory = DocumentBuilderFactory.newInstance()
   val parser = factory.newDocumentBuilder()
   val document = parser.parse(input)
+
+  // Each <class> element (inner classes included, e.g. "android/Clazz$Builder") is a symbol.
+  val classElements = document.getElementsByTagName("class")
+  // ktfmt doesn't understand the `..<` range syntax; explicitly call .rangeUntil instead
+  for (index in 0.rangeUntil(classElements.getLength())) {
+    val name =
+        requireNotNull(classElements.item(index).getAttribute("name")) {
+          "Bad XML: <class> element without name attribute"
+        }
+    output.add(Symbol.create(name))
+  }
+
   val fields = document.getElementsByTagName("field")
   // ktfmt doesn't understand the `..<` range syntax; explicitly call .rangeUntil instead
   for (i in 0.rangeUntil(fields.getLength())) {
@@ -216,6 +239,7 @@
             .getAttribute("name")
     output.add(Symbol.create("$className.$fieldName"))
   }
+
   return output
 }