Merge "stagefright: Fix seeking in MPEG4 container" into lmp-dev
diff --git a/media/libstagefright/MPEG4Extractor.cpp b/media/libstagefright/MPEG4Extractor.cpp
index 207acc8..19da6ee 100644
--- a/media/libstagefright/MPEG4Extractor.cpp
+++ b/media/libstagefright/MPEG4Extractor.cpp
@@ -3665,7 +3665,7 @@
 
         uint32_t sampleIndex;
         status_t err = mSampleTable->findSampleAtTime(
-                seekTimeUs * mTimescale / 1000000,
+                seekTimeUs, 1000000, mTimescale,
                 &sampleIndex, findFlags);
 
         if (mode == ReadOptions::SEEK_CLOSEST) {
diff --git a/media/libstagefright/OggExtractor.cpp b/media/libstagefright/OggExtractor.cpp
index 8c15929..0838004 100644
--- a/media/libstagefright/OggExtractor.cpp
+++ b/media/libstagefright/OggExtractor.cpp
@@ -320,14 +320,14 @@
     }
 
     size_t left = 0;
-    size_t right = mTableOfContents.size();
-    while (left < right) {
-        size_t center = left / 2 + right / 2 + (left & right & 1);
+    size_t right = mTableOfContents.size() - 1;
+    while (left <= right) {
+        size_t center = left + (right - left) / 2;
 
         const TOCEntry &entry = mTableOfContents.itemAt(center);
 
         if (timeUs < entry.mTimeUs) {
-            right = center;
+            right = center - 1;
         } else if (timeUs > entry.mTimeUs) {
             left = center + 1;
         } else {
diff --git a/media/libstagefright/SampleTable.cpp b/media/libstagefright/SampleTable.cpp
index 9a92805..c7e24fc 100644
--- a/media/libstagefright/SampleTable.cpp
+++ b/media/libstagefright/SampleTable.cpp
@@ -520,83 +520,72 @@
 }
 
 status_t SampleTable::findSampleAtTime(
-        uint32_t req_time, uint32_t *sample_index, uint32_t flags) {
+        uint64_t req_time, uint64_t scale_num, uint64_t scale_den,  // times scaled by num/den
+        uint32_t *sample_index, uint32_t flags) {  // flags: kFlagBefore/kFlagAfter/kFlagClosest
     buildSampleEntriesTable();
 
     uint32_t left = 0;
-    uint32_t right = mNumSampleSizes;
-    while (left < right) {
-        uint32_t center = (left + right) / 2;
-        uint32_t centerTime = mSampleTimeEntries[center].mCompositionTime;
+    uint32_t right_plus_one = mNumSampleSizes;  // exclusive bound: unsigned, must not go below 0
+    while (left < right_plus_one) {
+        uint32_t center = left + (right_plus_one - left) / 2;
+        uint64_t centerTime =
+            getSampleTime(center, scale_num, scale_den);
 
         if (req_time < centerTime) {
-            right = center;
+            right_plus_one = center;
         } else if (req_time > centerTime) {
             left = center + 1;
         } else {
-            left = center;
-            break;
+            *sample_index = mSampleTimeEntries[center].mSampleIndex;
+            return OK;
         }
     }
 
-    if (left == mNumSampleSizes) {
-        if (flags == kFlagAfter) {
-            return ERROR_OUT_OF_RANGE;
-        }
-
-        --left;
-    }
-
     uint32_t closestIndex = left;
 
+    if (closestIndex == mNumSampleSizes) {
+        if (flags == kFlagAfter || mNumSampleSizes == 0) {  // empty table: no sample to return
+            return ERROR_OUT_OF_RANGE;
+        }
+        flags = kFlagBefore;
+    } else if (closestIndex == 0) {
+        if (flags == kFlagBefore) {
+            // normally we should return out of range, but that is
+            // treated as end-of-stream.  instead return first sample
+            //
+            // return ERROR_OUT_OF_RANGE;
+        }
+        flags = kFlagAfter;
+    }
+
     switch (flags) {
         case kFlagBefore:
         {
-            while (closestIndex > 0
-                    && mSampleTimeEntries[closestIndex].mCompositionTime
-                            > req_time) {
-                --closestIndex;
-            }
+            --closestIndex;  // entry just below the insertion point found by the search
             break;
         }
 
         case kFlagAfter:
         {
-            while (closestIndex + 1 < mNumSampleSizes
-                    && mSampleTimeEntries[closestIndex].mCompositionTime
-                            < req_time) {
-                ++closestIndex;
-            }
+            // nothing to do: closestIndex is already the insertion point
             break;
         }
 
         default:
         {
             CHECK(flags == kFlagClosest);
-
-            if (closestIndex > 0) {
-                // Check left neighbour and pick closest.
-                uint32_t absdiff1 =
-                    abs_difference(
-                            mSampleTimeEntries[closestIndex].mCompositionTime,
-                            req_time);
-
-                uint32_t absdiff2 =
-                    abs_difference(
-                            mSampleTimeEntries[closestIndex - 1].mCompositionTime,
-                            req_time);
-
-                if (absdiff1 > absdiff2) {
-                    closestIndex = closestIndex - 1;
-                }
+            // pick closest based on timestamp. use abs_difference for safety
+            if (abs_difference(
+                    getSampleTime(closestIndex, scale_num, scale_den), req_time) >
+                abs_difference(
+                    req_time, getSampleTime(closestIndex - 1, scale_num, scale_den))) {
+                --closestIndex;
             }
-
             break;
         }
     }
 
     *sample_index = mSampleTimeEntries[closestIndex].mSampleIndex;
-
     return OK;
 }
 
@@ -618,109 +607,85 @@
     }
 
     uint32_t left = 0;
-    uint32_t right = mNumSyncSamples;
-    while (left < right) {
+    uint32_t right = mNumSyncSamples - 1;
+    while (left <= right) {
         uint32_t center = left + (right - left) / 2;
         uint32_t x = mSyncSamples[center];
 
         if (start_sample_index < x) {
-            right = center;
+            right = center - 1;
         } else if (start_sample_index > x) {
             left = center + 1;
         } else {
-            left = center;
-            break;
+            *sample_index = x;
+            return OK;
         }
     }
+
     if (left == mNumSyncSamples) {
         if (flags == kFlagAfter) {
             ALOGE("tried to find a sync frame after the last one: %d", left);
             return ERROR_OUT_OF_RANGE;
         }
-        left = left - 1;
+        flags = kFlagBefore;
+    }
+    else if (left == 0) {
+        if (flags == kFlagBefore) {
+            ALOGE("tried to find a sync frame before the first one: %d", left);
+
+            // normally we should return out of range, but that is
+            // treated as end-of-stream.  instead seek to first sync
+            //
+            // return ERROR_OUT_OF_RANGE;
+        }
+        flags = kFlagAfter;
     }
 
-    // Now ssi[left] is the sync sample index just before (or at)
-    // start_sample_index.
-    // Also start_sample_index < ssi[left + 1], if left + 1 < mNumSyncSamples.
-
-    uint32_t x = mSyncSamples[left];
-
-    if (left + 1 < mNumSyncSamples) {
-        uint32_t y = mSyncSamples[left + 1];
-
-        // our sample lies between sync samples x and y.
-
-        status_t err = mSampleIterator->seekTo(start_sample_index);
-        if (err != OK) {
-            return err;
-        }
-
-        uint32_t sample_time = mSampleIterator->getSampleTime();
-
-        err = mSampleIterator->seekTo(x);
-        if (err != OK) {
-            return err;
-        }
-        uint32_t x_time = mSampleIterator->getSampleTime();
-
-        err = mSampleIterator->seekTo(y);
-        if (err != OK) {
-            return err;
-        }
-
-        uint32_t y_time = mSampleIterator->getSampleTime();
-
-        if (abs_difference(x_time, sample_time)
-                > abs_difference(y_time, sample_time)) {
-            // Pick the sync sample closest (timewise) to the start-sample.
-            x = y;
-            ++left;
-        }
-    }
-
+    // Now ssi[left - 1] <(=) start_sample_index <= ssi[left]
     switch (flags) {
         case kFlagBefore:
         {
-            if (x > start_sample_index) {
-                CHECK(left > 0);
-
-                x = mSyncSamples[left - 1];
-
-                if (x > start_sample_index) {
-                    // The table of sync sample indices was not sorted
-                    // properly.
-                    return ERROR_MALFORMED;
-                }
-            }
+            --left;
             break;
         }
-
         case kFlagAfter:
         {
-            if (x < start_sample_index) {
-                if (left + 1 >= mNumSyncSamples) {
-                    return ERROR_OUT_OF_RANGE;
-                }
-
-                x = mSyncSamples[left + 1];
-
-                if (x < start_sample_index) {
-                    // The table of sync sample indices was not sorted
-                    // properly.
-                    return ERROR_MALFORMED;
-                }
-            }
-
+            // nothing to do
             break;
         }
-
         default:
+        {
+            // this route is not used, but implement it nonetheless
+            CHECK(flags == kFlagClosest);
+
+            status_t err = mSampleIterator->seekTo(start_sample_index);
+            if (err != OK) {
+                return err;
+            }
+            uint32_t sample_time = mSampleIterator->getSampleTime();
+
+            err = mSampleIterator->seekTo(mSyncSamples[left]);
+            if (err != OK) {
+                return err;
+            }
+            uint32_t upper_time = mSampleIterator->getSampleTime();
+
+            err = mSampleIterator->seekTo(mSyncSamples[left - 1]);
+            if (err != OK) {
+                return err;
+            }
+            uint32_t lower_time = mSampleIterator->getSampleTime();
+
+            // use abs_difference for safety
+            if (abs_difference(upper_time, sample_time) >
+                abs_difference(sample_time, lower_time)) {
+                --left;
+            }
             break;
+        }
     }
 
-    *sample_index = x;
-
+    *sample_index = mSyncSamples[left];
     return OK;
 }
 
diff --git a/media/libstagefright/include/SampleTable.h b/media/libstagefright/include/SampleTable.h
index fe146f2..d06df7b 100644
--- a/media/libstagefright/include/SampleTable.h
+++ b/media/libstagefright/include/SampleTable.h
@@ -75,7 +75,8 @@
         kFlagClosest
     };
     status_t findSampleAtTime(
-            uint32_t req_time, uint32_t *sample_index, uint32_t flags);
+            uint64_t req_time, uint64_t scale_num, uint64_t scale_den,
+            uint32_t *sample_index, uint32_t flags);
 
     status_t findSyncSampleNear(
             uint32_t start_sample_index, uint32_t *sample_index,
@@ -138,6 +139,13 @@
 
     friend struct SampleIterator;
 
+    // scaled composition time: mCompositionTime * scale_num / scale_den, truncated not rounded
+    inline uint64_t getSampleTime(
+            size_t sample_index, uint64_t scale_num, uint64_t scale_den) const {
+        return (mSampleTimeEntries[sample_index].mCompositionTime
+            * scale_num) / scale_den;
+    }
+
     status_t getSampleSize_l(uint32_t sample_index, size_t *sample_size);
     uint32_t getCompositionTimeOffset(uint32_t sampleIndex);