More HTTP live support, AES encryption etc.

Change-Id: Ia5088042dd0a2181cb73cf8c7a2ff81e34b3064c
related-to-bug: 2368598
diff --git a/media/libstagefright/Android.mk b/media/libstagefright/Android.mk
index 1df0bed..8fe1d4d 100644
--- a/media/libstagefright/Android.mk
+++ b/media/libstagefright/Android.mk
@@ -68,7 +68,8 @@
         libsurfaceflinger_client \
         libstagefright_yuv \
         libcamera_client \
-        libdrmframework
+        libdrmframework  \
+        libcrypto
 
 LOCAL_STATIC_LIBRARIES := \
         libstagefright_aacdec \
diff --git a/media/libstagefright/AwesomePlayer.cpp b/media/libstagefright/AwesomePlayer.cpp
index f084e28..41f5f30 100644
--- a/media/libstagefright/AwesomePlayer.cpp
+++ b/media/libstagefright/AwesomePlayer.cpp
@@ -49,6 +49,8 @@
 
 #include <media/stagefright/foundation/ALooper.h>
 
+#define USE_SURFACE_ALLOC 1
+
 namespace android {
 
 static int64_t kLowWaterMarkUs = 2000000ll;  // 2secs
@@ -294,6 +296,16 @@
 
     mUri = uri;
 
+    if (!strncmp("http://", uri, 7)) {
+        // Hack to support http live.
+
+        size_t len = strlen(uri);
+        if (!strcasecmp(&uri[len - 5], ".m3u8")) {
+            mUri = "httplive://";
+            mUri.append(&uri[7]);
+        }
+    }
+
     if (headers) {
         mUriHeaders = *headers;
     }
@@ -873,7 +885,7 @@
         IPCThreadState::self()->flushCommands();
 
         if (mSurface != NULL) {
-            if (strncmp(component, "OMX.", 4) == 0) {
+            if (USE_SURFACE_ALLOC && strncmp(component, "OMX.", 4) == 0) {
                 // Hardware decoders avoid the CPU color conversion by decoding
                 // directly to ANativeBuffers, so we must use a renderer that
                 // just pushes those buffers to the ANativeWindow.
@@ -1143,7 +1155,7 @@
             mClient.interface(), mVideoTrack->getFormat(),
             false, // createEncoder
             mVideoTrack,
-            NULL, flags, mSurface);
+            NULL, flags, USE_SURFACE_ALLOC ? mSurface : NULL);
 
     if (mVideoSource != NULL) {
         int64_t durationUs;
diff --git a/media/libstagefright/httplive/Android.mk b/media/libstagefright/httplive/Android.mk
index cc7dd4f..3aabf5f 100644
--- a/media/libstagefright/httplive/Android.mk
+++ b/media/libstagefright/httplive/Android.mk
@@ -9,7 +9,8 @@
 LOCAL_C_INCLUDES:= \
 	$(JNI_H_INCLUDE) \
 	$(TOP)/frameworks/base/include/media/stagefright/openmax \
-        $(TOP)/frameworks/base/media/libstagefright
+        $(TOP)/frameworks/base/media/libstagefright \
+        $(TOP)/external/openssl/include
 
 LOCAL_MODULE:= libstagefright_httplive
 
diff --git a/media/libstagefright/httplive/LiveSource.cpp b/media/libstagefright/httplive/LiveSource.cpp
index 4124571..39e3e75 100644
--- a/media/libstagefright/httplive/LiveSource.cpp
+++ b/media/libstagefright/httplive/LiveSource.cpp
@@ -22,9 +22,14 @@
 #include "include/M3UParser.h"
 #include "include/NuHTTPDataSource.h"
 
+#include <cutils/properties.h>
+#include <media/stagefright/foundation/hexdump.h>
 #include <media/stagefright/foundation/ABuffer.h>
+#include <media/stagefright/foundation/ADebug.h>
 #include <media/stagefright/FileSource.h>
-#include <media/stagefright/MediaDebug.h>
+
+#include <ctype.h>
+#include <openssl/aes.h>
 
 namespace android {
 
@@ -38,7 +43,9 @@
       mSourceSize(0),
       mOffsetBias(0),
       mSignalDiscontinuity(false),
-      mPrevBandwidthIndex(-1) {
+      mPrevBandwidthIndex(-1),
+      mAESKey((AES_KEY *)malloc(sizeof(AES_KEY))),
+      mStreamEncrypted(false) {
     if (switchToNext()) {
         mInitCheck = OK;
 
@@ -47,6 +54,8 @@
 }
 
 LiveSource::~LiveSource() {
+    free(mAESKey);
+    mAESKey = NULL;
 }
 
 status_t LiveSource::initCheck() const {
@@ -68,7 +77,77 @@
     return (double)rand() / RAND_MAX;
 }
 
-bool LiveSource::loadPlaylist(bool fetchMaster) {
+size_t LiveSource::getBandwidthIndex() {
+    if (mBandwidthItems.size() == 0) {
+        return 0;
+    }
+
+#if 1
+    int32_t bandwidthBps;
+    if (mSource != NULL && mSource->estimateBandwidth(&bandwidthBps)) {
+        LOGI("bandwidth estimated at %.2f kbps", bandwidthBps / 1024.0f);
+    } else {
+        LOGI("no bandwidth estimate.");
+        return 0;  // Pick the lowest bandwidth stream by default.
+    }
+
+    char value[PROPERTY_VALUE_MAX];
+    if (property_get("media.httplive.max-bw", value, NULL)) {
+        char *end;
+        long maxBw = strtoul(value, &end, 10);
+        if (end > value && *end == '\0') {
+            if (maxBw > 0 && bandwidthBps > maxBw) {
+                LOGV("bandwidth capped to %ld bps", maxBw);
+                bandwidthBps = maxBw;
+            }
+        }
+    }
+
+    // Consider only 80% of the available bandwidth usable.
+    bandwidthBps = (bandwidthBps * 8) / 10;
+
+    // Pick the highest bandwidth stream below or equal to estimated bandwidth.
+
+    size_t index = mBandwidthItems.size() - 1;
+    while (index > 0 && mBandwidthItems.itemAt(index).mBandwidth
+                            > (size_t)bandwidthBps) {
+        --index;
+    }
+#elif 0
+    // Change bandwidth at random()
+    size_t index = uniformRand() * mBandwidthItems.size();
+#elif 0
+    // There's a 50% chance to stay on the current bandwidth and
+    // a 50% chance to switch to the next higher bandwidth (wrapping around
+    // to lowest)
+    const size_t kMinIndex = 0;
+
+    size_t index;
+    if (mPrevBandwidthIndex < 0) {
+        index = kMinIndex;
+    } else if (uniformRand() < 0.5) {
+        index = (size_t)mPrevBandwidthIndex;
+    } else {
+        index = mPrevBandwidthIndex + 1;
+        if (index == mBandwidthItems.size()) {
+            index = kMinIndex;
+        }
+    }
+#elif 0
+    // Pick the highest bandwidth stream below or equal to 1.2 Mbit/sec
+
+    size_t index = mBandwidthItems.size() - 1;
+    while (index > 0 && mBandwidthItems.itemAt(index).mBandwidth > 1200000) {
+        --index;
+    }
+#else
+    size_t index = mBandwidthItems.size() - 1;  // Highest bandwidth stream
+#endif
+
+    return index;
+}
+
+bool LiveSource::loadPlaylist(bool fetchMaster, size_t bandwidthIndex) {
     mSignalDiscontinuity = false;
 
     mPlaylist.clear();
@@ -112,49 +191,35 @@
 
             mBandwidthItems.sort(SortByBandwidth);
 
+#if 1  // XXX
+            if (mBandwidthItems.size() > 1) {
+                // Remove the lowest bandwidth stream, this is sometimes
+                // an AAC program stream, which we don't support at this point.
+                mBandwidthItems.removeItemsAt(0);
+            }
+#endif
+
             for (size_t i = 0; i < mBandwidthItems.size(); ++i) {
                 const BandwidthItem &item = mBandwidthItems.itemAt(i);
                 LOGV("item #%d: %s", i, item.mURI.c_str());
             }
+
+            bandwidthIndex = getBandwidthIndex();
         }
     }
 
     if (mBandwidthItems.size() > 0) {
-#if 0
-        // Change bandwidth at random()
-        size_t index = uniformRand() * mBandwidthItems.size();
-#elif 0
-        // There's a 50% chance to stay on the current bandwidth and
-        // a 50% chance to switch to the next higher bandwidth (wrapping around
-        // to lowest)
-        size_t index;
-        if (uniformRand() < 0.5) {
-            index = mPrevBandwidthIndex < 0 ? 0 : (size_t)mPrevBandwidthIndex;
-        } else {
-            if (mPrevBandwidthIndex < 0) {
-                index = 0;
-            } else {
-                index = mPrevBandwidthIndex + 1;
-                if (index == mBandwidthItems.size()) {
-                    index = 0;
-                }
-            }
-        }
-#else
-        // Stay on the lowest bandwidth available.
-        size_t index = mBandwidthItems.size() - 1;  // Highest bandwidth stream
-#endif
+        mURL = mBandwidthItems.editItemAt(bandwidthIndex).mURI;
 
-        mURL = mBandwidthItems.editItemAt(index).mURI;
-
-        if (mPrevBandwidthIndex >= 0 && (size_t)mPrevBandwidthIndex != index) {
+        if (mPrevBandwidthIndex >= 0
+                && (size_t)mPrevBandwidthIndex != bandwidthIndex) {
             // If we switched streams because of bandwidth changes,
             // we'll signal this discontinuity by inserting a
             // special transport stream packet into the stream.
             mSignalDiscontinuity = true;
         }
 
-        mPrevBandwidthIndex = index;
+        mPrevBandwidthIndex = bandwidthIndex;
     } else {
         mURL = mMasterURL;
     }
@@ -199,12 +264,15 @@
     mOffsetBias += mSourceSize;
     mSourceSize = 0;
 
+    size_t bandwidthIndex = getBandwidthIndex();
+
     if (mLastFetchTimeUs < 0 || getNowUs() >= mLastFetchTimeUs + 15000000ll
-        || mPlaylistIndex == mPlaylist->size()) {
+        || mPlaylistIndex == mPlaylist->size()
+        || (ssize_t)bandwidthIndex != mPrevBandwidthIndex) {
         int32_t nextSequenceNumber =
             mPlaylistIndex + mFirstItemSequenceNumber;
 
-        if (!loadPlaylist(mLastFetchTimeUs < 0)) {
+        if (!loadPlaylist(mLastFetchTimeUs < 0, bandwidthIndex)) {
             LOGE("failed to reload playlist");
             return false;
         }
@@ -227,6 +295,10 @@
         mLastFetchTimeUs = getNowUs();
     }
 
+    if (!setupCipher()) {
+        return false;
+    }
+
     AString uri;
     sp<AMessage> itemMeta;
     CHECK(mPlaylist->itemAt(mPlaylistIndex, &uri, &itemMeta));
@@ -243,6 +315,121 @@
     }
 
     mPlaylistIndex++;
+
+    return true;
+}
+
+bool LiveSource::setupCipher() {
+    sp<AMessage> itemMeta;
+    bool found = false;
+    AString method;
+
+    for (ssize_t i = mPlaylistIndex; i >= 0; --i) {
+        AString uri;
+        CHECK(mPlaylist->itemAt(i, &uri, &itemMeta));
+
+        if (itemMeta->findString("cipher-method", &method)) {
+            found = true;
+            break;
+        }
+    }
+
+    if (!found) {
+        method = "NONE";
+    }
+
+    mStreamEncrypted = false;
+
+    if (method == "AES-128") {
+        AString keyURI;
+        if (!itemMeta->findString("cipher-uri", &keyURI)) {
+            LOGE("Missing key uri");
+            return false;
+        }
+
+        if (keyURI.size() >= 2
+                && keyURI.c_str()[0] == '"'
+                && keyURI.c_str()[keyURI.size() - 1] == '"') {
+            // Remove surrounding quotes.
+            AString tmp(keyURI, 1, keyURI.size() - 2);
+            keyURI = tmp;
+        }
+
+        ssize_t index = mAESKeyForURI.indexOfKey(keyURI);
+
+        sp<ABuffer> key;
+        if (index >= 0) {
+            key = mAESKeyForURI.valueAt(index);
+        } else {
+            key = new ABuffer(16);
+
+            sp<NuHTTPDataSource> keySource = new NuHTTPDataSource;
+            status_t err = keySource->connect(keyURI.c_str());
+
+            if (err == OK) {
+                size_t offset = 0;
+                while (offset < 16) {
+                    ssize_t n = keySource->readAt(
+                            offset, key->data() + offset, 16 - offset);
+                    if (n <= 0) {
+                        err = ERROR_IO;
+                        break;
+                    }
+
+                    offset += n;
+                }
+            }
+
+            if (err != OK) {
+                LOGE("failed to fetch cipher key from '%s'.", keyURI.c_str());
+                return false;
+            }
+
+            mAESKeyForURI.add(keyURI, key);
+        }
+
+        if (AES_set_decrypt_key(key->data(), 128, (AES_KEY *)mAESKey) != 0) {
+            LOGE("failed to set AES decryption key.");
+            return false;
+        }
+
+        AString iv;
+        if (itemMeta->findString("cipher-iv", &iv)) {
+            if ((!iv.startsWith("0x") && !iv.startsWith("0X"))
+                    || iv.size() != 16 * 2 + 2) {
+                LOGE("malformed cipher IV '%s'.", iv.c_str());
+                return false;
+            }
+
+            memset(mAESIVec, 0, sizeof(mAESIVec));
+            for (size_t i = 0; i < 16; ++i) {
+                char c1 = tolower(iv.c_str()[2 + 2 * i]);
+                char c2 = tolower(iv.c_str()[3 + 2 * i]);
+                if (!isxdigit(c1) || !isxdigit(c2)) {
+                    LOGE("malformed cipher IV '%s'.", iv.c_str());
+                    return false;
+                }
+                uint8_t nibble1 = isdigit(c1) ? c1 - '0' : c1 - 'a' + 10;
+                uint8_t nibble2 = isdigit(c2) ? c2 - '0' : c2 - 'a' + 10;
+
+                mAESIVec[i] = nibble1 << 4 | nibble2;
+            }
+        } else {
+            size_t seqNum = mPlaylistIndex + mFirstItemSequenceNumber;
+
+            memset(mAESIVec, 0, sizeof(mAESIVec));
+            mAESIVec[15] = seqNum & 0xff;
+            mAESIVec[14] = (seqNum >> 8) & 0xff;
+            mAESIVec[13] = (seqNum >> 16) & 0xff;
+            mAESIVec[12] = (seqNum >> 24) & 0xff;
+        }
+
+        mStreamEncrypted = true;
+    } else if (!(method == "NONE")) {
+        LOGE("Unsupported cipher method '%s'", method.c_str());
+        return false;
+    }
+
     return true;
 }
 
@@ -279,6 +466,7 @@
         return avail;
     }
 
+    bool done = false;
     size_t numRead = 0;
     while (numRead < size) {
         ssize_t n = mSource->readAt(
@@ -289,7 +477,44 @@
             break;
         }
 
+        if (mStreamEncrypted) {
+            size_t nmod = n % 16;
+            CHECK(nmod == 0);
+
+            sp<ABuffer> tmp = new ABuffer(n);
+
+            AES_cbc_encrypt((const unsigned char *)data + numRead,
+                            tmp->data(),
+                            n,
+                            (const AES_KEY *)mAESKey,
+                            mAESIVec,
+                            AES_DECRYPT);
+
+            if (mSourceSize == (off_t)(offset + numRead - delta + n)) {
+                // check for padding at the end of the file.
+
+                size_t pad = tmp->data()[n - 1];
+                CHECK_GT(pad, 0u);
+                CHECK_LE(pad, 16u);
+                CHECK_GE((size_t)n, pad);
+                for (size_t i = 0; i < pad; ++i) {
+                    CHECK_EQ((unsigned)tmp->data()[n - 1 - i], pad);
+                }
+
+                n -= pad;
+                mSourceSize -= pad;
+
+                done = true;
+            }
+
+            memcpy((uint8_t *)data + numRead, tmp->data(), n);
+        }
+
         numRead += n;
+
+        if (done) {
+            break;
+        }
     }
 
     return numRead;
@@ -359,19 +584,17 @@
         return false;
     }
 
-    size_t newPlaylistIndex = mFirstItemSequenceNumber + index;
-
-    if (newPlaylistIndex == mPlaylistIndex) {
+    if (index == mPlaylistIndex) {
         return false;
     }
 
-    mPlaylistIndex = newPlaylistIndex;
+    mPlaylistIndex = index;
+
+    LOGV("seeking to index %lld", index);
 
     switchToNext();
     mOffsetBias = 0;
 
-    LOGV("seeking to index %lld", index);
-
     return true;
 }
 
diff --git a/media/libstagefright/httplive/M3UParser.cpp b/media/libstagefright/httplive/M3UParser.cpp
index 90f3d6d..b166cc3 100644
--- a/media/libstagefright/httplive/M3UParser.cpp
+++ b/media/libstagefright/httplive/M3UParser.cpp
@@ -158,6 +158,11 @@
                     return ERROR_MALFORMED;
                 }
                 err = parseMetaData(line, &mMeta, "media-sequence");
+            } else if (line.startsWith("#EXT-X-KEY")) {
+                if (mIsVariantPlaylist) {
+                    return ERROR_MALFORMED;
+                }
+                err = parseCipherInfo(line, &itemMeta);
             } else if (line.startsWith("#EXT-X-ENDLIST")) {
                 mIsComplete = true;
             } else if (line.startsWith("#EXTINF")) {
@@ -292,6 +297,57 @@
 }
 
 // static
+status_t M3UParser::parseCipherInfo(
+        const AString &line, sp<AMessage> *meta) {
+    ssize_t colonPos = line.find(":");
+
+    if (colonPos < 0) {
+        return ERROR_MALFORMED;
+    }
+
+    size_t offset = colonPos + 1;
+
+    while (offset < line.size()) {
+        ssize_t end = line.find(",", offset);
+        if (end < 0) {
+            end = line.size();
+        }
+
+        AString attr(line, offset, end - offset);
+        attr.trim();
+
+        offset = end + 1;
+
+        ssize_t equalPos = attr.find("=");
+        if (equalPos < 0) {
+            continue;
+        }
+
+        AString key(attr, 0, equalPos);
+        key.trim();
+
+        AString val(attr, equalPos + 1, attr.size() - equalPos - 1);
+        val.trim();
+
+        LOGV("key=%s value=%s", key.c_str(), val.c_str());
+
+        key.tolower();
+
+        if (key == "method" || key == "uri" || key == "iv") {
+            if (meta->get() == NULL) {
+                *meta = new AMessage;
+            }
+
+            key.insert(AString("cipher-"), 0);
+
+            (*meta)->setString(key.c_str(), val.c_str(), val.size());
+        }
+    }
+
+    return OK;
+}
+
+// static
 status_t M3UParser::ParseInt32(const char *s, int32_t *x) {
     char *end;
     long lval = strtol(s, &end, 10);
diff --git a/media/libstagefright/include/LiveSource.h b/media/libstagefright/include/LiveSource.h
index 55dd45e..7ba1f44 100644
--- a/media/libstagefright/include/LiveSource.h
+++ b/media/libstagefright/include/LiveSource.h
@@ -21,6 +21,7 @@
 #include <media/stagefright/foundation/ABase.h>
 #include <media/stagefright/foundation/AString.h>
 #include <media/stagefright/DataSource.h>
+#include <utils/KeyedVector.h>
 #include <utils/Vector.h>
 
 namespace android {
@@ -72,14 +73,23 @@
     bool mSignalDiscontinuity;
     ssize_t mPrevBandwidthIndex;
 
+    void *mAESKey;
+    unsigned char mAESIVec[16];
+    bool mStreamEncrypted;
+
+    KeyedVector<AString, sp<ABuffer> > mAESKeyForURI;
+
     status_t fetchM3U(const char *url, sp<ABuffer> *buffer);
 
     static int SortByBandwidth(const BandwidthItem *a, const BandwidthItem *b);
 
     bool switchToNext();
-    bool loadPlaylist(bool fetchMaster);
+    bool loadPlaylist(bool fetchMaster, size_t bandwidthIndex);
     void determineSeekability();
 
+    size_t getBandwidthIndex();
+    bool setupCipher();
+
     DISALLOW_EVIL_CONSTRUCTORS(LiveSource);
 };
 
diff --git a/media/libstagefright/include/M3UParser.h b/media/libstagefright/include/M3UParser.h
index bd9eebe..531d184 100644
--- a/media/libstagefright/include/M3UParser.h
+++ b/media/libstagefright/include/M3UParser.h
@@ -66,6 +66,9 @@
     static status_t parseStreamInf(
             const AString &line, sp<AMessage> *meta);
 
+    static status_t parseCipherInfo(
+            const AString &line, sp<AMessage> *meta);
+
     static status_t ParseInt32(const char *s, int32_t *x);
 
     DISALLOW_EVIL_CONSTRUCTORS(M3UParser);
diff --git a/media/libstagefright/mpeg2ts/ATSParser.cpp b/media/libstagefright/mpeg2ts/ATSParser.cpp
index c88c6c1..f06a1bb 100644
--- a/media/libstagefright/mpeg2ts/ATSParser.cpp
+++ b/media/libstagefright/mpeg2ts/ATSParser.cpp
@@ -274,6 +274,8 @@
       mQueue(streamType == 0x1b
               ? ElementaryStreamQueue::H264 : ElementaryStreamQueue::AAC) {
     mBuffer->setRange(0, 0);
+
+    LOGV("new stream PID 0x%02x, type 0x%02x", elementaryPID, streamType);
 }
 
 ATSParser::Stream::~Stream() {
@@ -307,7 +309,8 @@
 }
 
 void ATSParser::Stream::signalDiscontinuity(bool isASeek) {
-    LOGV("Stream discontinuity");
+    isASeek = false;  // Always signal a "real" discontinuity
+
     mPayloadStarted = false;
     mBuffer->setRange(0, 0);
 
@@ -317,7 +320,9 @@
         // This is only a "minor" discontinuity, we stay within the same
         // bitstream.
 
-        mSource->clear();
+        if (mSource != NULL) {
+            mSource->clear();
+        }
         return;
     }
 
diff --git a/media/libstagefright/mpeg2ts/ESQueue.cpp b/media/libstagefright/mpeg2ts/ESQueue.cpp
index b0b9e66..f11b3c3 100644
--- a/media/libstagefright/mpeg2ts/ESQueue.cpp
+++ b/media/libstagefright/mpeg2ts/ESQueue.cpp
@@ -41,7 +41,10 @@
 }
 
 void ElementaryStreamQueue::clear() {
-    mBuffer->setRange(0, 0);
+    if (mBuffer != NULL) {
+        mBuffer->setRange(0, 0);
+    }
+
     mTimestamps.clear();
     mFormat.clear();
 }