MPEG4Writer: adjust the whole CTTS table

- refactor ListTableEntries to use constant entryCapacity

Bug: 30182452
Change-Id: Ib883a8547e198fab85c63ac756117e8e11384c73
diff --git a/media/libstagefright/MPEG4Writer.cpp b/media/libstagefright/MPEG4Writer.cpp
index ec534ef..66c4ffd 100644
--- a/media/libstagefright/MPEG4Writer.cpp
+++ b/media/libstagefright/MPEG4Writer.cpp
@@ -30,6 +30,8 @@
 
 #include <utils/Log.h>
 
+#include <functional>
+
 #include <media/stagefright/foundation/ADebug.h>
 #include <media/stagefright/foundation/AMessage.h>
 #include <media/stagefright/foundation/AUtils.h>
@@ -122,18 +124,18 @@
     };
 
     // A helper class to handle faster write box with table entries
-    template<class TYPE>
+    template<class TYPE, unsigned ENTRY_SIZE>
+    // ENTRY_SIZE: # of values in each entry
     struct ListTableEntries {
-        ListTableEntries(uint32_t elementCapacity, uint32_t entryCapacity)
+        static_assert(ENTRY_SIZE > 0, "ENTRY_SIZE must be positive");
+        ListTableEntries(uint32_t elementCapacity)
             : mElementCapacity(elementCapacity),
-            mEntryCapacity(entryCapacity),
             mTotalNumTableEntries(0),
             mNumValuesInCurrEntry(0),
             mCurrTableEntriesElement(NULL) {
             CHECK_GT(mElementCapacity, 0);
-            CHECK_GT(mEntryCapacity, 0);
             // Ensure no integer overflow on allocation in add().
-            CHECK_LT(mEntryCapacity, UINT32_MAX / mElementCapacity);
+            CHECK_LT(ENTRY_SIZE, UINT32_MAX / mElementCapacity);
         }
 
         // Free the allocated memory.
@@ -150,10 +152,10 @@
         // @arg value must be in network byte order
         // @arg pos location the value must be in.
         void set(const TYPE& value, uint32_t pos) {
-            CHECK_LT(pos, mTotalNumTableEntries * mEntryCapacity);
+            CHECK_LT(pos, mTotalNumTableEntries * ENTRY_SIZE);
 
             typename List<TYPE *>::iterator it = mTableEntryList.begin();
-            uint32_t iterations = (pos / (mElementCapacity * mEntryCapacity));
+            uint32_t iterations = (pos / (mElementCapacity * ENTRY_SIZE));
             while (it != mTableEntryList.end() && iterations > 0) {
                 ++it;
                 --iterations;
@@ -161,7 +163,7 @@
             CHECK(it != mTableEntryList.end());
             CHECK_EQ(iterations, 0);
 
-            (*it)[(pos % (mElementCapacity * mEntryCapacity))] = value;
+            (*it)[(pos % (mElementCapacity * ENTRY_SIZE))] = value;
         }
 
         // Get the value at the given position by the given value.
@@ -169,12 +171,12 @@
         // @arg pos location the value must be in.
         // @return true if a value is found.
         bool get(TYPE& value, uint32_t pos) const {
-            if (pos >= mTotalNumTableEntries * mEntryCapacity) {
+            if (pos >= mTotalNumTableEntries * ENTRY_SIZE) {
                 return false;
             }
 
             typename List<TYPE *>::iterator it = mTableEntryList.begin();
-            uint32_t iterations = (pos / (mElementCapacity * mEntryCapacity));
+            uint32_t iterations = (pos / (mElementCapacity * ENTRY_SIZE));
             while (it != mTableEntryList.end() && iterations > 0) {
                 ++it;
                 --iterations;
@@ -182,27 +184,42 @@
             CHECK(it != mTableEntryList.end());
             CHECK_EQ(iterations, 0);
 
-            value = (*it)[(pos % (mElementCapacity * mEntryCapacity))];
+            value = (*it)[(pos % (mElementCapacity * ENTRY_SIZE))];
             return true;
         }
 
+        // adjusts all values by |adjust(value)|
+        void adjustEntries(
+                std::function<void(size_t /* ix */, TYPE(& /* entry */)[ENTRY_SIZE])> update) {
+            size_t nEntries = mTotalNumTableEntries + mNumValuesInCurrEntry / ENTRY_SIZE;
+            size_t ix = 0;
+            for (TYPE *entryArray : mTableEntryList) {
+                size_t num = std::min(nEntries, (size_t)mElementCapacity);
+                for (size_t i = 0; i < num; ++i) {
+                    update(ix++, (TYPE(&)[ENTRY_SIZE])(*entryArray));
+                    entryArray += ENTRY_SIZE;
+                }
+                nEntries -= num;
+            }
+        }
+
         // Store a single value.
         // @arg value must be in network byte order.
         void add(const TYPE& value) {
             CHECK_LT(mNumValuesInCurrEntry, mElementCapacity);
             uint32_t nEntries = mTotalNumTableEntries % mElementCapacity;
-            uint32_t nValues  = mNumValuesInCurrEntry % mEntryCapacity;
+            uint32_t nValues  = mNumValuesInCurrEntry % ENTRY_SIZE;
             if (nEntries == 0 && nValues == 0) {
-                mCurrTableEntriesElement = new TYPE[mEntryCapacity * mElementCapacity];
+                mCurrTableEntriesElement = new TYPE[ENTRY_SIZE * mElementCapacity];
                 CHECK(mCurrTableEntriesElement != NULL);
                 mTableEntryList.push_back(mCurrTableEntriesElement);
             }
 
-            uint32_t pos = nEntries * mEntryCapacity + nValues;
+            uint32_t pos = nEntries * ENTRY_SIZE + nValues;
             mCurrTableEntriesElement[pos] = value;
 
             ++mNumValuesInCurrEntry;
-            if ((mNumValuesInCurrEntry % mEntryCapacity) == 0) {
+            if ((mNumValuesInCurrEntry % ENTRY_SIZE) == 0) {
                 ++mTotalNumTableEntries;
                 mNumValuesInCurrEntry = 0;
             }
@@ -213,17 +230,17 @@
         // 2. followed by the values in the table enties in order
         // @arg writer the writer to actual write to the storage
         void write(MPEG4Writer *writer) const {
-            CHECK_EQ(mNumValuesInCurrEntry % mEntryCapacity, 0);
+            CHECK_EQ(mNumValuesInCurrEntry % ENTRY_SIZE, 0);
             uint32_t nEntries = mTotalNumTableEntries;
             writer->writeInt32(nEntries);
             for (typename List<TYPE *>::iterator it = mTableEntryList.begin();
                 it != mTableEntryList.end(); ++it) {
                 CHECK_GT(nEntries, 0);
                 if (nEntries >= mElementCapacity) {
-                    writer->write(*it, sizeof(TYPE) * mEntryCapacity, mElementCapacity);
+                    writer->write(*it, sizeof(TYPE) * ENTRY_SIZE, mElementCapacity);
                     nEntries -= mElementCapacity;
                 } else {
-                    writer->write(*it, sizeof(TYPE) * mEntryCapacity, nEntries);
+                    writer->write(*it, sizeof(TYPE) * ENTRY_SIZE, nEntries);
                     break;
                 }
             }
@@ -234,9 +251,8 @@
 
     private:
         uint32_t         mElementCapacity;  // # entries in an element
-        uint32_t         mEntryCapacity;    // # of values in each entry
         uint32_t         mTotalNumTableEntries;
-        uint32_t         mNumValuesInCurrEntry;  // up to mEntryCapacity
+        uint32_t         mNumValuesInCurrEntry;  // up to ENTRY_SIZE
         TYPE             *mCurrTableEntriesElement;
         mutable List<TYPE *>     mTableEntryList;
 
@@ -271,14 +287,14 @@
     List<MediaBuffer *> mChunkSamples;
 
     bool                mSamplesHaveSameSize;
-    ListTableEntries<uint32_t> *mStszTableEntries;
+    ListTableEntries<uint32_t, 1> *mStszTableEntries;
 
-    ListTableEntries<uint32_t> *mStcoTableEntries;
-    ListTableEntries<off64_t> *mCo64TableEntries;
-    ListTableEntries<uint32_t> *mStscTableEntries;
-    ListTableEntries<uint32_t> *mStssTableEntries;
-    ListTableEntries<uint32_t> *mSttsTableEntries;
-    ListTableEntries<uint32_t> *mCttsTableEntries;
+    ListTableEntries<uint32_t, 1> *mStcoTableEntries;
+    ListTableEntries<off64_t, 1> *mCo64TableEntries;
+    ListTableEntries<uint32_t, 3> *mStscTableEntries;
+    ListTableEntries<uint32_t, 1> *mStssTableEntries;
+    ListTableEntries<uint32_t, 2> *mSttsTableEntries;
+    ListTableEntries<uint32_t, 2> *mCttsTableEntries;
 
     int64_t mMinCttsOffsetTimeUs;
     int64_t mMaxCttsOffsetTimeUs;
@@ -1524,13 +1540,13 @@
       mTrackDurationUs(0),
       mEstimatedTrackSizeBytes(0),
       mSamplesHaveSameSize(true),
-      mStszTableEntries(new ListTableEntries<uint32_t>(1000, 1)),
-      mStcoTableEntries(new ListTableEntries<uint32_t>(1000, 1)),
-      mCo64TableEntries(new ListTableEntries<off64_t>(1000, 1)),
-      mStscTableEntries(new ListTableEntries<uint32_t>(1000, 3)),
-      mStssTableEntries(new ListTableEntries<uint32_t>(1000, 1)),
-      mSttsTableEntries(new ListTableEntries<uint32_t>(1000, 2)),
-      mCttsTableEntries(new ListTableEntries<uint32_t>(1000, 2)),
+      mStszTableEntries(new ListTableEntries<uint32_t, 1>(1000)),
+      mStcoTableEntries(new ListTableEntries<uint32_t, 1>(1000)),
+      mCo64TableEntries(new ListTableEntries<off64_t, 1>(1000)),
+      mStscTableEntries(new ListTableEntries<uint32_t, 3>(1000)),
+      mStssTableEntries(new ListTableEntries<uint32_t, 1>(1000)),
+      mSttsTableEntries(new ListTableEntries<uint32_t, 2>(1000)),
+      mCttsTableEntries(new ListTableEntries<uint32_t, 2>(1000)),
       mCodecSpecificData(NULL),
       mCodecSpecificDataSize(0),
       mGotAllCodecSpecificData(false),
@@ -3382,10 +3398,12 @@
 
     mOwner->beginBox("ctts");
     mOwner->writeInt32(0);  // version=0, flags=0
-    uint32_t duration;
-    CHECK(mCttsTableEntries->get(duration, 1));
-    duration = htonl(duration);  // Back host byte order
-    mCttsTableEntries->set(htonl(duration + getStartTimeOffsetScaledTime() - mMinCttsOffsetTimeUs), 1);
+    uint32_t delta = mMinCttsOffsetTimeUs - getStartTimeOffsetScaledTime();
+    mCttsTableEntries->adjustEntries([delta](size_t /* ix */, uint32_t (&value)[2]) {
+        // entries are <count, ctts> pairs; adjust only ctts
+        uint32_t duration = htonl(value[1]); // back to host byte order
+        value[1] = htonl(duration - delta);
+    });
     mCttsTableEntries->write(mOwner);
     mOwner->endBox();  // ctts
 }