check-flagged-apis: parse classes

Teach check-flagged-apis to parse classes, including inner classes.

Bug: 334870672
Test: atest --host check-flagged-apis-test
Change-Id: I17f65d3af55a20a1920b47f4c47fd0e92f9fa852
diff --git a/tools/check-flagged-apis/src/com/android/checkflaggedapis/Main.kt b/tools/check-flagged-apis/src/com/android/checkflaggedapis/Main.kt
index c4c5b11..918a5d9 100644
--- a/tools/check-flagged-apis/src/com/android/checkflaggedapis/Main.kt
+++ b/tools/check-flagged-apis/src/com/android/checkflaggedapis/Main.kt
@@ -19,7 +19,9 @@
 
 import android.aconfig.Aconfig
 import com.android.tools.metalava.model.BaseItemVisitor
+import com.android.tools.metalava.model.ClassItem
 import com.android.tools.metalava.model.FieldItem
+import com.android.tools.metalava.model.Item
 import com.android.tools.metalava.model.text.ApiFile
 import com.github.ajalt.clikt.core.CliktCommand
 import com.github.ajalt.clikt.core.ProgramResult
@@ -167,22 +169,39 @@
 }

 internal fun parseApiSignature(path: String, input: InputStream): Set<Pair<Symbol, Flag>> {
-  // TODO(334870672): add support for classes and metods
+  // TODO(334870672): add support for methods
   val output = mutableSetOf<Pair<Symbol, Flag>>()
   val visitor =
       object : BaseItemVisitor() {
-        override fun visitField(field: FieldItem) {
-          val flag =
-              field.modifiers
-                  .findAnnotation("android.annotation.FlaggedApi")
-                  ?.findAttribute("value")
-                  ?.value
-                  ?.value() as? String
-          if (flag != null) {
-            val symbol = Symbol.create(field.baselineElementId())
-            output.add(Pair(symbol, Flag(flag)))
+        override fun visitClass(cls: ClassItem) {
+          getFlagOrNull(cls)?.let { flag ->
+            val symbol = Symbol.create(cls.baselineElementId())
+            output.add(Pair(symbol, flag))
           }
         }
+
+        override fun visitField(field: FieldItem) {
+          getFlagOrNull(field)?.let { flag ->
+            val symbol = Symbol.create(field.baselineElementId())
+            output.add(Pair(symbol, flag))
+          }
+        }
+
+        /**
+         * Returns the flag named by the `value` attribute of [item]'s
+         * `@android.annotation.FlaggedApi` annotation, or null if the item has no such
+         * annotation, the annotation has no `value`, or the value is not a String.
+         */
+        private fun getFlagOrNull(item: Item): Flag? {
+          // `as?` (not `as`): an unexpected value type yields null, not a ClassCastException.
+          val flagName =
+              item.modifiers
+                  .findAnnotation("android.annotation.FlaggedApi")
+                  ?.findAttribute("value")
+                  ?.value
+                  ?.value() as? String
+          return flagName?.let { Flag(it) }
+        }
       }
   val codebase = ApiFile.parseApi(path, input)
   codebase.accept(visitor)
@@ -203,6 +214,18 @@
   val factory = DocumentBuilderFactory.newInstance()
   val parser = factory.newDocumentBuilder()
   val document = parser.parse(input)
+
+  // Each <class> element contributes one symbol, named by its mandatory
+  // `name` attribute; a missing attribute means the XML is malformed.
+  val classNodes = document.getElementsByTagName("class")
+  // ktfmt doesn't understand the `..<` range syntax; explicitly call .rangeUntil instead
+  for (index in 0.rangeUntil(classNodes.getLength())) {
+    val nameAttr = classNodes.item(index).getAttribute("name")
+    val className =
+        requireNotNull(nameAttr) { "Bad XML: <class> element without name attribute" }
+    output.add(Symbol.create(className))
+  }
+
   val fields = document.getElementsByTagName("field")
   // ktfmt doesn't understand the `..<` range syntax; explicitly call .rangeUntil instead
   for (i in 0.rangeUntil(fields.getLength())) {
@@ -216,6 +239,7 @@
             .getAttribute("name")
     output.add(Symbol.create("$className.$fieldName"))
   }
+
   return output
 }