/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| 16 | |
| 17 | #include "PosePredictor.h" |
| 18 | |
| 19 | namespace android::media { |
| 20 | |
namespace {

// Compile-time configuration of prediction verification.
//
// When ENABLE_VERIFICATION is defined, verification is on and errors are
// tracked for look-ahead horizons of 50, 100 and 200 (milliseconds).
// Otherwise verification is off and the look-ahead table is empty.
#if defined(ENABLE_VERIFICATION)
constexpr std::array<int, 3> kLookAheadMs{ 50, 100, 200 };
constexpr bool kEnableVerification = true;
#else
constexpr std::array<int, 0> kLookAheadMs{};
constexpr bool kEnableVerification = false;
#endif

}  // namespace
| 31 | |
| 32 | void LeastSquaresPredictor::add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) |
| 33 | { |
| 34 | (void)twist; |
| 35 | mLastAtNs = atNs; |
| 36 | mLastPose = pose; |
| 37 | const auto q = pose.rotation(); |
| 38 | const double datNs = static_cast<double>(atNs); |
| 39 | mRw.add({datNs, q.w()}); |
| 40 | mRx.add({datNs, q.x()}); |
| 41 | mRy.add({datNs, q.y()}); |
| 42 | mRz.add({datNs, q.z()}); |
| 43 | } |
| 44 | |
| 45 | Pose3f LeastSquaresPredictor::predict(int64_t atNs) const |
| 46 | { |
| 47 | if (mRw.getN() < kMinimumSamplesForPrediction) return mLastPose; |
| 48 | |
| 49 | /* |
| 50 | * Using parametric form, we have q(t) = { w(t), x(t), y(t), z(t) }. |
| 51 | * We compute the least squares prediction of w, x, y, z. |
| 52 | */ |
| 53 | const double dLookahead = static_cast<double>(atNs); |
| 54 | Eigen::Quaternionf lsq( |
| 55 | mRw.getYFromX(dLookahead), |
| 56 | mRx.getYFromX(dLookahead), |
| 57 | mRy.getYFromX(dLookahead), |
| 58 | mRz.getYFromX(dLookahead)); |
| 59 | |
| 60 | /* |
| 61 | * We cheat here, since the result lsq is the least squares prediction |
| 62 | * in H (arbitrary quaternion), not the least squares prediction in |
| 63 | * SO(3) (unit quaternion). |
| 64 | * |
| 65 | * In other words, the result for lsq is most likely not a unit quaternion. |
| 66 | * To solve this, we normalize, thereby selecting the closest unit quaternion |
| 67 | * in SO(3) to the prediction in H. |
| 68 | */ |
| 69 | lsq.normalize(); |
| 70 | return Pose3f(lsq); |
| 71 | } |
| 72 | |
| 73 | void LeastSquaresPredictor::reset() { |
| 74 | mLastAtNs = {}; |
| 75 | mLastPose = {}; |
| 76 | mRw.reset(); |
| 77 | mRx.reset(); |
| 78 | mRy.reset(); |
| 79 | mRz.reset(); |
| 80 | } |
| 81 | |
| 82 | std::string LeastSquaresPredictor::toString(size_t index) const { |
| 83 | std::string s(index, ' '); |
| 84 | s.append("LeastSquaresPredictor using alpha: ") |
| 85 | .append(std::to_string(mAlpha)) |
| 86 | .append(" last pose: ") |
| 87 | .append(mLastPose.toString()) |
| 88 | .append("\n"); |
| 89 | return s; |
| 90 | } |
| 91 | |
// Formatting

/**
 * Computes where group delimiters go when printing the flat verifier vector.
 *
 * PosePredictor::predict() stores the entry for lookahead i and predictor j
 * at index (i * predictors + j), so the vector consists of `lookaheads`
 * consecutive groups of `predictors` entries each.  A delimiter belongs
 * between adjacent groups, i.e. at predictors, 2 * predictors, ...
 *
 * @param predictors number of predictors (size of each group).
 * @param lookaheads number of lookahead horizons (number of groups).
 * @return (lookaheads - 1) delimiter indices in ascending order; empty when
 *         there is at most one group or when the groups are empty.
 */
static inline std::vector<size_t> createDelimiterIdx(size_t predictors, size_t lookaheads) {
    // No entries at all (e.g. verification disabled): nothing to delimit.
    if (predictors == 0 || lookaheads == 0) return {};

    const size_t delimiters = lookaheads - 1;
    std::vector<size_t> delimiterIdx(delimiters);
    for (size_t i = 0; i < delimiters; ++i) {
        delimiterIdx[i] = (i + 1) * predictors;
    }
    return delimiterIdx;
}
| 102 | |
| 103 | PosePredictor::PosePredictor() |
| 104 | : mPredictors{ // must match switch in getCurrentPredictor() |
| 105 | std::make_shared<LastPredictor>(), |
| 106 | std::make_shared<TwistPredictor>(), |
| 107 | std::make_shared<LeastSquaresPredictor>(), |
| 108 | } |
| 109 | , mLookaheadMs(kLookAheadMs.begin(), kLookAheadMs.end()) |
| 110 | , mVerifiers(std::size(mLookaheadMs) * std::size(mPredictors)) |
| 111 | , mDelimiterIdx(createDelimiterIdx(std::size(mPredictors), std::size(mLookaheadMs))) |
| 112 | , mPredictionRecorder( |
| 113 | std::size(mVerifiers) /* vectorSize */, std::chrono::seconds(1), 10 /* maxLogLine */, |
| 114 | mDelimiterIdx) |
| 115 | , mPredictionDurableRecorder( |
| 116 | std::size(mVerifiers) /* vectorSize */, std::chrono::minutes(1), 10 /* maxLogLine */, |
| 117 | mDelimiterIdx) |
| 118 | { |
| 119 | } |
| 120 | |
| 121 | Pose3f PosePredictor::predict( |
| 122 | int64_t timestampNs, const Pose3f& pose, const Twist3f& twist, float predictionDurationNs) |
| 123 | { |
| 124 | if (timestampNs - mLastTimestampNs > kMaximumSampleIntervalBeforeResetNs) { |
| 125 | for (const auto& predictor : mPredictors) { |
| 126 | predictor->reset(); |
| 127 | } |
| 128 | ++mResets; |
| 129 | } |
| 130 | mLastTimestampNs = timestampNs; |
| 131 | |
| 132 | auto selectedPredictor = getCurrentPredictor(); |
| 133 | if constexpr (kEnableVerification) { |
| 134 | // Update all Predictors |
| 135 | for (const auto& predictor : mPredictors) { |
| 136 | predictor->add(timestampNs, pose, twist); |
| 137 | } |
| 138 | |
| 139 | // Update Verifiers and calculate errors |
| 140 | std::vector<float> error(std::size(mVerifiers)); |
| 141 | for (size_t i = 0; i < mLookaheadMs.size(); ++i) { |
| 142 | constexpr float RADIAN_TO_DEGREES = 180 / M_PI; |
| 143 | const int64_t atNs = |
| 144 | timestampNs + mLookaheadMs[i] * PosePredictorVerifier::kMillisToNanos; |
| 145 | |
| 146 | for (size_t j = 0; j < mPredictors.size(); ++j) { |
| 147 | const size_t idx = i * std::size(mPredictors) + j; |
| 148 | mVerifiers[idx].verifyActualPose(timestampNs, pose); |
| 149 | mVerifiers[idx].addPredictedPose(atNs, mPredictors[j]->predict(atNs)); |
| 150 | error[idx] = RADIAN_TO_DEGREES * mVerifiers[idx].lastError(); |
| 151 | } |
| 152 | } |
| 153 | // Record errors |
| 154 | mPredictionRecorder.record(error); |
| 155 | mPredictionDurableRecorder.record(error); |
| 156 | } else /* constexpr */ { |
| 157 | selectedPredictor->add(timestampNs, pose, twist); |
| 158 | } |
| 159 | |
| 160 | // Deliver prediction |
| 161 | const int64_t predictionTimeNs = timestampNs + (int64_t)predictionDurationNs; |
| 162 | return selectedPredictor->predict(predictionTimeNs); |
| 163 | } |
| 164 | |
| 165 | void PosePredictor::setPosePredictorType(PosePredictorType type) { |
| 166 | if (!isValidPosePredictorType(type)) return; |
| 167 | if (type == mSetType) return; |
| 168 | mSetType = type; |
| 169 | if (type == android::media::PosePredictorType::AUTO) { |
| 170 | type = android::media::PosePredictorType::LEAST_SQUARES; |
| 171 | } |
| 172 | if (type != mCurrentType) { |
| 173 | mCurrentType = type; |
| 174 | if constexpr (!kEnableVerification) { |
| 175 | // Verification keeps all predictors up-to-date. |
| 176 | // If we don't enable verification, we must reset the current predictor. |
| 177 | getCurrentPredictor()->reset(); |
| 178 | } |
| 179 | } |
| 180 | } |
| 181 | |
| 182 | std::string PosePredictor::toString(size_t index) const { |
| 183 | std::string prefixSpace(index, ' '); |
| 184 | std::string ss(prefixSpace); |
| 185 | ss.append("PosePredictor:\n") |
| 186 | .append(prefixSpace) |
| 187 | .append(" Current Prediction Type: ") |
| 188 | .append(android::media::toString(mCurrentType)) |
| 189 | .append("\n") |
| 190 | .append(prefixSpace) |
| 191 | .append(" Resets: ") |
| 192 | .append(std::to_string(mResets)) |
| 193 | .append("\n") |
| 194 | .append(getCurrentPredictor()->toString(index + 1)); |
| 195 | if constexpr (kEnableVerification) { |
| 196 | // dump verification |
| 197 | ss.append(prefixSpace) |
| 198 | .append(" Prediction abs error (L1) degrees [ type (last twist least-squares) x ( "); |
| 199 | for (size_t i = 0; i < mLookaheadMs.size(); ++i) { |
| 200 | if (i > 0) ss.append(" : "); |
| 201 | ss.append(std::to_string(mLookaheadMs[i])); |
| 202 | } |
| 203 | std::vector<float> cumulativeAverageErrors(std::size(mVerifiers)); |
| 204 | for (size_t i = 0; i < cumulativeAverageErrors.size(); ++i) { |
| 205 | cumulativeAverageErrors[i] = mVerifiers[i].cumulativeAverageError(); |
| 206 | } |
| 207 | ss.append(" ) ms ]\n") |
| 208 | .append(prefixSpace) |
| 209 | .append(" Cumulative Average Error:\n") |
| 210 | .append(prefixSpace) |
| 211 | .append(" ") |
| 212 | .append(VectorRecorder::toString(cumulativeAverageErrors, mDelimiterIdx, "%.3g")) |
| 213 | .append("\n") |
| 214 | .append(prefixSpace) |
| 215 | .append(" PerMinuteHistory:\n") |
| 216 | .append(mPredictionDurableRecorder.toString(index + 3)) |
| 217 | .append(prefixSpace) |
| 218 | .append(" PerSecondHistory:\n") |
| 219 | .append(mPredictionRecorder.toString(index + 3)); |
| 220 | } |
| 221 | return ss; |
| 222 | } |
| 223 | |
| 224 | std::shared_ptr<PredictorBase> PosePredictor::getCurrentPredictor() const { |
| 225 | // we don't use a map here, we look up directly |
| 226 | switch (mCurrentType) { |
| 227 | default: |
| 228 | case android::media::PosePredictorType::LAST: |
| 229 | return mPredictors[0]; |
| 230 | case android::media::PosePredictorType::TWIST: |
| 231 | return mPredictors[1]; |
| 232 | case android::media::PosePredictorType::AUTO: // shouldn't occur here. |
| 233 | case android::media::PosePredictorType::LEAST_SQUARES: |
| 234 | return mPredictors[2]; |
| 235 | } |
| 236 | } |
| 237 | |
| 238 | } // namespace android::media |