Merge "Speed up SniffMidi() by parsing the initial bytes of the header" into main
diff --git a/media/module/extractors/Android.bp b/media/module/extractors/Android.bp
index 8f9fbab..e29d3e6 100644
--- a/media/module/extractors/Android.bp
+++ b/media/module/extractors/Android.bp
@@ -73,3 +73,21 @@
         ],
     },
 }
+
+aconfig_declarations {
+    name: "android.media.extractor.flags-aconfig",
+    package: "com.android.media.extractor.flags",
+    container: "com.android.media",
+    srcs: ["extractor.aconfig"],
+}
+
+cc_aconfig_library {
+    name: "android.media.extractor.flags-aconfig-cc",
+    aconfig_declarations: "android.media.extractor.flags-aconfig",
+    host_supported: true,
+    min_sdk_version: "29",
+    apex_available: [
+        "//apex_available:platform",
+        "com.android.media",
+    ],
+}
diff --git a/media/module/extractors/extractor.aconfig b/media/module/extractors/extractor.aconfig
new file mode 100644
index 0000000..c61efa4
--- /dev/null
+++ b/media/module/extractors/extractor.aconfig
@@ -0,0 +1,13 @@
+# Media Extractor flags.
+#
+# !!! Please add flags in alphabetical order. !!!
+package: "com.android.media.extractor.flags"
+container: "com.android.media"
+
+flag {
+    name: "extractor_sniff_midi_optimizations"
+    is_exported: true
+    namespace: "media_extractor"
+    description: "Enable SniffMidi optimizations."
+    bug: "359920208"
+}
diff --git a/media/module/extractors/fuzzers/Android.bp b/media/module/extractors/fuzzers/Android.bp
index 9ed768a..3da1589 100644
--- a/media/module/extractors/fuzzers/Android.bp
+++ b/media/module/extractors/fuzzers/Android.bp
@@ -24,6 +24,7 @@
     // to get the below license kinds:
     //   SPDX-license-identifier-Apache-2.0
     default_applicable_licenses: ["frameworks_av_license"],
+    default_team: "trendy_team_android_media_solutions_playback",
 }
 
 cc_defaults {
@@ -302,12 +303,18 @@
     ],
 
     static_libs: [
+        "android.media.extractor.flags-aconfig-cc",
+        "libaconfig_storage_read_api_cc",
         "libsonivox",
         "libmedia_midiiowrapper",
         "libmidiextractor",
         "libwatchdog",
     ],
 
+    shared_libs: [
+        "server_configurable_flags",
+    ],
+
     dictionary: "midi_extractor_fuzzer.dict",
 
     host_supported: true,
diff --git a/media/module/extractors/midi/Android.bp b/media/module/extractors/midi/Android.bp
index feabf9e..0eb34fc 100644
--- a/media/module/extractors/midi/Android.bp
+++ b/media/module/extractors/midi/Android.bp
@@ -32,6 +32,8 @@
     ],
 
     static_libs: [
+        "android.media.extractor.flags-aconfig-cc",
+        "libaconfig_storage_read_api_cc",
         "libmedia_midiiowrapper",
         "libsonivoxwithoutjet",
         "libstagefright_foundation",
@@ -40,6 +42,7 @@
 
     shared_libs: [
         "libbase",
+        "server_configurable_flags",
     ],
 
     host_supported: true,
diff --git a/media/module/extractors/midi/MidiExtractor.cpp b/media/module/extractors/midi/MidiExtractor.cpp
index 167cc40..98d7716 100644
--- a/media/module/extractors/midi/MidiExtractor.cpp
+++ b/media/module/extractors/midi/MidiExtractor.cpp
@@ -20,6 +20,7 @@
 
 #include "MidiExtractor.h"
 
+#include <com_android_media_extractor_flags.h>
 #include <media/MidiIoWrapper.h>
 #include <media/stagefright/foundation/ADebug.h>
 #include <media/stagefright/MediaBufferGroup.h>
@@ -323,10 +324,97 @@
     return AMediaFormat_copy(meta, mFileMetadata);
 }
 
-// Sniffer
+// True iff buf (size valid bytes) begins with the patternSize bytes found at pattern.
+static bool startsWith(const uint8_t *buf, size_t size, const char *pattern, size_t patternSize) {
+    // Checking the length first keeps memcmp from reading past the valid bytes of buf.
+    const bool longEnough = (size >= patternSize);
+    return longEnough && (memcmp(buf, pattern, patternSize) == 0);
+}
 
-bool SniffMidi(CDataSource *source, float *confidence)
-{
+static bool isValidMThd(const uint8_t *buf, size_t size) {  // does buf open with "MThd"?
+    return startsWith(buf, size, "MThd", sizeof("MThd") - 1);  // -1 drops the literal's NUL
+}
+
+static bool isValidXmf(const uint8_t *buf, size_t size) {  // does buf open with "XMF_"?
+    return startsWith(buf, size, "XMF_", sizeof("XMF_") - 1);  // -1 drops the literal's NUL
+}
+
+static bool isValidImelody(const uint8_t *buf, size_t size) {  // opens with "BEGIN:IMELODY"?
+    return startsWith(buf, size, "BEGIN:IMELODY", sizeof("BEGIN:IMELODY") - 1);  // no NUL
+}
+
+static bool isValidRtttl(const uint8_t *buf, size_t size) {
+    // Shape check for the start of an RTTTL stream:
+    //   <title>:<type>=<value>
+    //
+    // - <title>: the ':' ending it must occur within the first RTTTL_MAX_TITLE_LEN bytes
+    // - <type>:  one character: 'd' (duration), 'o' (octave) or 'b' (beats per minute)
+    // - <value>: the value for that type
+    //
+    // Only the positions of ':' and '=' are verified; <type> and <value> are not inspected.
+    #define RTTTL_MAX_TITLE_LEN 32
+    if (size < 4) {
+        return false;
+    }
+
+    // Look for the first ':' inside the title window (clamped to the bytes available).
+    const size_t window = (size < RTTTL_MAX_TITLE_LEN) ? size : RTTTL_MAX_TITLE_LEN;
+    const uint8_t *colon = (const uint8_t *)memchr(buf, ':', window);
+    if (colon == NULL) {
+        return false;
+    }
+
+    // One byte of <type> and then '=' must follow, with at least one byte after the '='.
+    const size_t colonAt = (size_t)(colon - buf);
+    return colonAt < (size - 3) && buf[colonAt + 2] == '=';
+}
+
+static bool isValidOta(const uint8_t *buf, size_t size) {
+    #define OTA_RINGTONE 0x25
+    #define OTA_SOUND 0x1d
+    #define OTA_UNICODE 0x22
+
+    // Shape check for the start of an OTA stream:
+    //   <cmdLen><cmd1><cmd2>..<cmdN>
+    //
+    // - <cmdLen>: one byte, the command length
+    // - <cmd1>:   one byte whose upper seven bits equal OTA_RINGTONE
+    // - <cmd2>..: one byte each, upper seven bits OTA_UNICODE (skipped) or OTA_SOUND
+    //
+    // The buffer is accepted as soon as an OTA_SOUND command is seen; any other command,
+    // or running out of commands or bytes first, rejects it.
+
+    // Need at least <cmdLen>, <cmd1> and one further command byte.
+    if (size < 3) {
+        return false;
+    }
+
+    const uint8_t declaredCmds = buf[0];
+    if (declaredCmds < 2) {
+        return false;
+    }
+    if ((buf[1] >> 1) != OTA_RINGTONE) {
+        return false;
+    }
+
+    // buf[1] consumed one command; walk the rest, bounded by both the declared
+    // command length and the number of bytes actually available.
+    size_t remaining = declaredCmds - 1;
+    for (size_t pos = 2; remaining != 0 && pos < size; pos++, remaining--) {
+        const int cmd = buf[pos] >> 1;
+        if (cmd == OTA_SOUND) {
+            return true;
+        }
+        if (cmd != OTA_UNICODE) {
+            return false;
+        }
+        // OTA_UNICODE: keep scanning.
+    }
+
+    return false;
+}
+
+bool SniffMidiLegacy(CDataSource *source, float *confidence) {
     MidiEngine p(source, NULL, NULL);
     if (p.initCheck() == OK) {
         *confidence = 0.8;
@@ -335,7 +423,47 @@
     }
     ALOGV("SniffMidi: no");
     return false;
+}
 
+// Probes only the first bytes of the source against the known MIDI-family header shapes.
+// Returns true and sets *confidence on a match; *confidence is left untouched otherwise.
+bool SniffMidiEfficiently(CDataSource *source, float *confidence) {
+    uint8_t header[128];
+    int filled = source->readAt(source->handle, 0, header, sizeof(header));
+    if (filled <= 0) {
+        // A negative (error) result would be converted to a huge size_t by the checks
+        // below and let them compare against the uninitialized buffer; nothing to sniff.
+        ALOGV("SniffMidi: no");
+        return false;
+    }
+    const size_t valid = static_cast<size_t>(filled);
+    const char *kind = NULL;
+    if (isValidMThd(header, valid)) {
+        kind = "MThd";
+    } else if (isValidXmf(header, valid)) {
+        kind = "XMF_";
+    } else if (isValidImelody(header, valid)) {
+        kind = "imelody";
+    } else if (isValidRtttl(header, valid)) {
+        kind = "rtttl";
+    } else if (isValidOta(header, valid)) {
+        kind = "ota";
+    }
+    if (kind == NULL) {
+        ALOGV("SniffMidi: no");
+        return false;
+    }
+    *confidence = 0.80;
+    ALOGV("SniffMidi: yes, %s", kind);
+    return true;
+}
+
+// Sniffer: picks the header-only probe when the aconfig flag is on, else SniffMidiLegacy.
+bool SniffMidi(CDataSource *source, float *confidence) {
+    const bool fastPath =
+            com::android::media::extractor::flags::extractor_sniff_midi_optimizations();
+    return fastPath ? SniffMidiEfficiently(source, confidence)
+                    : SniffMidiLegacy(source, confidence);
 }
 
 static const char *extensions[] = {