/*
 * Copyright (C) 2016 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.net;

import static android.app.usage.NetworkStatsManager.MIN_THRESHOLD_BYTES;

import android.annotation.NonNull;
import android.app.usage.NetworkStatsManager;
import android.content.Context;
import android.content.pm.PackageManager;
import android.net.DataUsageRequest;
import android.net.NetworkIdentitySet;
import android.net.NetworkStack;
import android.net.NetworkStats;
import android.net.NetworkStatsAccess;
import android.net.NetworkStatsCollection;
import android.net.NetworkStatsHistory;
import android.net.NetworkTemplate;
import android.net.netstats.IUsageCallback;
import android.os.Handler;
import android.os.HandlerThread;
import android.os.IBinder;
import android.os.Looper;
import android.os.Message;
import android.os.Process;
import android.os.RemoteException;
import android.util.ArrayMap;
import android.util.IndentingPrintWriter;
import android.util.Log;
import android.util.SparseArray;

import com.android.internal.annotations.VisibleForTesting;
import com.android.net.module.util.PerUidCounter;

import java.util.concurrent.atomic.AtomicInteger;

/**
 * Manages observers of {@link NetworkStats}. Allows observers to be notified when
 * data usage has been reported in {@link NetworkStatsService}. An observer can set
 * a threshold of how much data it cares about to be notified.
 */
class NetworkStatsObservers {
    private static final String TAG = "NetworkStatsObservers";
    private static final boolean LOG = true;
    private static final boolean LOGV = false;

    private static final int MSG_REGISTER = 1;
    private static final int MSG_UNREGISTER = 2;
    private static final int MSG_UPDATE_STATS = 3;

    private static final int DUMP_USAGE_REQUESTS_COUNT = 200;

    // The maximum number of requests allowed per uid before an exception is thrown.
    @VisibleForTesting
    static final int MAX_REQUESTS_PER_UID = 100;

    // All access to this map must be done from the handler thread.
    // Indexed by DataUsageRequest#requestId.
    private final SparseArray<RequestInfo> mDataUsageRequests = new SparseArray<>();

    // Request counters per uid; this is thread safe.
    private final PerUidCounter mDataUsageRequestsPerUid = new PerUidCounter(MAX_REQUESTS_PER_UID);

    // Sequence number of DataUsageRequests.
    private final AtomicInteger mNextDataUsageRequestId = new AtomicInteger();

    // Lazily instantiated the first time a message is posted; see getHandler().
    private volatile Handler mHandler;

    /**
     * Normalizes {@code inputRequest} and registers a data usage observer for it.
     *
     * <p>Normalization assigns a fresh request id. For callers without the network stack
     * permission, it also raises the threshold to at least {@code MIN_THRESHOLD_BYTES}. The
     * observer is added asynchronously on the handler thread, and the death recipient is linked
     * there as well. This method is therefore safe to call from any thread.
     *
     * @return the normalized request, which identifies this observer in callbacks and in
     *         {@link #unregister}.
     * @throws IllegalArgumentException if {@code accessLevel} is above USER but below
     *         DEVICESUMMARY.
     * Also propagates the exception thrown by {@code PerUidCounter#incrementCountOrThrow} when
     * {@code callingUid} already holds {@link #MAX_REQUESTS_PER_UID} requests.
     */
    public DataUsageRequest register(@NonNull Context context,
            @NonNull DataUsageRequest inputRequest, @NonNull IUsageCallback callback,
            int callingPid, int callingUid, @NonNull String callingPackage,
            @NetworkStatsAccess.Level int accessLevel) {
        final DataUsageRequest request = buildRequest(context, inputRequest, callingUid);
        final RequestInfo requestInfo = buildRequestInfo(request, callback, callingPid,
                callingUid, callingPackage, accessLevel);
        if (LOG) Log.d(TAG, "Registering observer for " + requestInfo);
        // Nothing has been linked to the client binder yet; that happens in handleRegister().
        // Throwing here on the per-uid limit therefore leaves no dangling state behind.
        mDataUsageRequestsPerUid.incrementCountOrThrow(callingUid);

        final Handler handler = getHandler();
        handler.sendMessage(handler.obtainMessage(MSG_REGISTER, requestInfo));
        return request;
    }

    /**
     * Unregisters a data usage observer.
     *
     * <p>The removal happens asynchronously on the handler thread, so this method is safe to
     * call from any thread. The request is only removed if {@code callingUid} owns it or is
     * {@link Process#SYSTEM_UID}.
     */
    public void unregister(DataUsageRequest request, int callingUid) {
        final Handler handler = getHandler();
        handler.sendMessage(handler.obtainMessage(MSG_UNREGISTER, callingUid, 0 /* ignore */,
                request));
    }

    /**
     * Updates data usage statistics of registered observers and notifies them when their
     * thresholds are reached.
     *
     * <p>The update happens asynchronously on the handler thread, so this method is safe to
     * call from any thread.
     */
    public void updateStats(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
                ArrayMap<String, NetworkIdentitySet> activeIfaces,
                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
                long currentTime) {
        final StatsContext statsContext = new StatsContext(xtSnapshot, uidSnapshot, activeIfaces,
                activeUidIfaces, currentTime);
        final Handler handler = getHandler();
        handler.sendMessage(handler.obtainMessage(MSG_UPDATE_STATS, statsContext));
    }

    /** Returns the handler, creating it under {@code this} lock on first use. */
    private Handler getHandler() {
        Handler handler = mHandler;
        if (handler == null) {
            synchronized (this) {
                handler = mHandler;
                if (handler == null) {
                    if (LOGV) Log.v(TAG, "Creating handler");
                    handler = new Handler(getHandlerLooperLocked(), mHandlerCallback);
                    mHandler = handler;
                }
            }
        }
        return handler;
    }

    @VisibleForTesting
    protected Looper getHandlerLooperLocked() {
        // TODO: Currently, callbacks are dispatched on this thread if the caller registers a
        //  callback without supplying a Handler. To ensure that the service handler thread
        //  is not blocked by client code, the observers must create their own thread. Once
        //  all callbacks are dispatched outside of the handler thread, the service handler
        //  thread can be used here.
        HandlerThread handlerThread = new HandlerThread(TAG);
        handlerThread.start();
        return handlerThread.getLooper();
    }

    private final Handler.Callback mHandlerCallback = new Handler.Callback() {
        @Override
        public boolean handleMessage(Message msg) {
            switch (msg.what) {
                case MSG_REGISTER: {
                    handleRegister((RequestInfo) msg.obj);
                    return true;
                }
                case MSG_UNREGISTER: {
                    handleUnregister((DataUsageRequest) msg.obj, msg.arg1 /* callingUid */);
                    return true;
                }
                case MSG_UPDATE_STATS: {
                    handleUpdateStats((StatsContext) msg.obj);
                    return true;
                }
                default: {
                    return false;
                }
            }
        }
    };

    /**
     * Adds a {@link RequestInfo} as an observer after linking it to the client's death.
     * Should only be called from the handler thread; otherwise there will be a race condition
     * on mDataUsageRequests.
     *
     * <p>Linking happens here rather than at construction time. A death notification posted
     * by binderDied() is then always processed after this registration, so it can find and
     * remove the request.
     */
    private void handleRegister(RequestInfo requestInfo) {
        if (!requestInfo.linkDeathRecipient()) {
            // The client died before the request reached the handler thread. There is no one
            // left to notify, so drop the request and return the per-uid slot taken in
            // register(); otherwise both would leak.
            Log.w(TAG, "Client died before registration completed, dropping " + requestInfo);
            mDataUsageRequestsPerUid.decrementCountOrThrow(requestInfo.mCallingUid);
            return;
        }
        mDataUsageRequests.put(requestInfo.mRequest.requestId, requestInfo);
    }

    /**
     * Removes a {@link DataUsageRequest} if the calling uid is authorized, and notifies the
     * client with {@link NetworkStatsManager#CALLBACK_RELEASED}.
     * Should only be called from the handler thread; otherwise there will be a race condition
     * on mDataUsageRequests.
     */
    private void handleUnregister(DataUsageRequest request, int callingUid) {
        final RequestInfo requestInfo = mDataUsageRequests.get(request.requestId);
        if (requestInfo == null) {
            if (LOG) Log.d(TAG, "Trying to unregister unknown request " + request);
            return;
        }
        if (Process.SYSTEM_UID != callingUid && requestInfo.mCallingUid != callingUid) {
            Log.w(TAG, "Caller uid " + callingUid + " is not owner of " + request);
            return;
        }

        if (LOG) Log.d(TAG, "Unregistering " + requestInfo);
        mDataUsageRequests.remove(request.requestId);
        mDataUsageRequestsPerUid.decrementCountOrThrow(requestInfo.mCallingUid);
        // Every request in the map was successfully linked in handleRegister(), so this unlink
        // always pairs with an earlier link.
        requestInfo.unlinkDeathRecipient();
        requestInfo.callCallback(NetworkStatsManager.CALLBACK_RELEASED);
    }

    /** Feeds a new stats sample to every registered observer. Handler thread only. */
    private void handleUpdateStats(StatsContext statsContext) {
        for (int i = 0; i < mDataUsageRequests.size(); i++) {
            mDataUsageRequests.valueAt(i).updateStats(statsContext);
        }
    }

    /**
     * Returns a copy of {@code request} with a fresh request id. For callers without the
     * network stack permission, the threshold is raised to at least MIN_THRESHOLD_BYTES.
     */
    private DataUsageRequest buildRequest(Context context, DataUsageRequest request,
                int callingUid) {
        // For non-NETWORK_STACK permission uid, cap the minimum threshold to a safe default to
        // avoid too many callbacks.
        final long thresholdInBytes = (context.checkPermission(
                NetworkStack.PERMISSION_MAINLINE_NETWORK_STACK, Process.myPid(), callingUid)
                == PackageManager.PERMISSION_GRANTED ? request.thresholdInBytes
                : Math.max(MIN_THRESHOLD_BYTES, request.thresholdInBytes));
        if (thresholdInBytes > request.thresholdInBytes) {
            Log.w(TAG, "Threshold was too low for " + request
                    + ". Overriding to a safer default of " + thresholdInBytes + " bytes");
        }
        return new DataUsageRequest(mNextDataUsageRequestId.incrementAndGet(),
                request.template, thresholdInBytes);
    }

    /**
     * Picks the {@link RequestInfo} flavor for {@code accessLevel}:
     * <ul>
     *   <li>levels up to USER get a per-uid observer;</li>
     *   <li>levels from DEVICESUMMARY up get a network-wide observer.</li>
     * </ul>
     *
     * @throws IllegalArgumentException for levels strictly between USER and DEVICESUMMARY.
     */
    private RequestInfo buildRequestInfo(DataUsageRequest request, IUsageCallback callback,
            int callingPid, int callingUid, @NonNull String callingPackage,
            @NetworkStatsAccess.Level int accessLevel) {
        if (accessLevel <= NetworkStatsAccess.Level.USER) {
            return new UserUsageRequestInfo(this, request, callback, callingPid,
                    callingUid, callingPackage, accessLevel);
        }
        // Safety check in case a new access level is added and we forgot to update this.
        if (accessLevel < NetworkStatsAccess.Level.DEVICESUMMARY) {
            throw new IllegalArgumentException(
                    "accessLevel " + accessLevel + " is less than DEVICESUMMARY.");
        }
        return new NetworkUsageRequestInfo(this, request, callback, callingPid,
                callingUid, callingPackage, accessLevel);
    }

    /**
     * Tracks information relevant to a data usage observer.
     * Once linked via {@link #linkDeathRecipient()}, it notices when the calling process dies
     * so the request can self-expire.
     */
    private abstract static class RequestInfo implements IBinder.DeathRecipient {
        private final NetworkStatsObservers mStatsObserver;
        protected final DataUsageRequest mRequest;
        private final IUsageCallback mCallback;
        protected final int mCallingPid;
        protected final int mCallingUid;
        protected final String mCallingPackage;
        protected final @NetworkStatsAccess.Level int mAccessLevel;
        protected NetworkStatsRecorder mRecorder;
        protected NetworkStatsCollection mCollection;

        RequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
                IUsageCallback callback, int callingPid, int callingUid,
                @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) {
            // Deliberately no binder calls here: linking to death is deferred to
            // linkDeathRecipient() so that `this` does not escape a half-built object.
            mStatsObserver = statsObserver;
            mRequest = request;
            mCallback = callback;
            mCallingPid = callingPid;
            mCallingUid = callingUid;
            mCallingPackage = callingPackage;
            mAccessLevel = accessLevel;
        }

        @Override
        public void binderDied() {
            if (LOGV) {
                Log.v(TAG, "RequestInfo binderDied(" + mRequest + ", " + mCallback + ")");
            }
            // handleUnregister() sends CALLBACK_RELEASED once the request is removed. A second,
            // direct notification to the dead client would only fail.
            mStatsObserver.unregister(mRequest, Process.SYSTEM_UID);
        }

        @Override
        public String toString() {
            return "RequestInfo from pid/uid:" + mCallingPid + "/" + mCallingUid
                    + "(" + mCallingPackage + ")"
                    + " for " + mRequest + " accessLevel:" + mAccessLevel;
        }

        /**
         * Links this object as the death recipient of the client callback.
         *
         * @return {@code false} if the client binder is already dead (linking threw
         *         {@link RemoteException}), {@code true} otherwise.
         */
        private boolean linkDeathRecipient() {
            try {
                mCallback.asBinder().linkToDeath(this, 0);
                return true;
            } catch (RemoteException e) {
                return false;
            }
        }

        private void unlinkDeathRecipient() {
            mCallback.asBinder().unlinkToDeath(this, 0);
        }

        /**
         * Updates stats given the samples and interface-to-identity mappings.
         * The first call only establishes a baseline. Later calls fire
         * CALLBACK_LIMIT_REACHED and reset the baseline once {@link #checkStats()} reports
         * that the threshold was crossed.
         */
        private void updateStats(StatsContext statsContext) {
            if (mRecorder == null) {
                // First run; establish baseline stats.
                resetRecorder();
                recordSample(statsContext);
                return;
            }
            recordSample(statsContext);

            if (checkStats()) {
                resetRecorder();
                callCallback(NetworkStatsManager.CALLBACK_LIMIT_REACHED);
            }
        }

        /** Sends {@code callbackType} to the client, logging (not throwing) on binder failure. */
        private void callCallback(int callbackType) {
            try {
                if (LOGV) {
                    Log.v(TAG, "sending notification " + callbackTypeToName(callbackType)
                            + " for " + mRequest);
                }
                switch (callbackType) {
                    case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
                        mCallback.onThresholdReached(mRequest);
                        break;
                    case NetworkStatsManager.CALLBACK_RELEASED:
                        mCallback.onCallbackReleased(mRequest);
                        break;
                }
            } catch (RemoteException e) {
                // May occur naturally in the race of binder death.
                Log.w(TAG, "RemoteException caught trying to send a callback msg for " + mRequest);
            }
        }

        private void resetRecorder() {
            mRecorder = new NetworkStatsRecorder();
            mCollection = mRecorder.getSinceBoot();
        }

        /** Returns whether usage recorded since the last reset exceeds the request threshold. */
        protected abstract boolean checkStats();

        /** Records the relevant snapshot from {@code statsContext} into {@link #mRecorder}. */
        protected abstract void recordSample(StatsContext statsContext);

        private String callbackTypeToName(int callbackType) {
            switch (callbackType) {
                case NetworkStatsManager.CALLBACK_LIMIT_REACHED:
                    return "LIMIT_REACHED";
                case NetworkStatsManager.CALLBACK_RELEASED:
                    return "RELEASED";
                default:
                    return "UNKNOWN";
            }
        }
    }

    /** Observer over the total usage of a network template (DEVICESUMMARY and above). */
    private static class NetworkUsageRequestInfo extends RequestInfo {
        NetworkUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
                IUsageCallback callback, int callingPid, int callingUid,
                @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) {
            super(statsObserver, request, callback, callingPid, callingUid, callingPackage,
                    accessLevel);
        }

        @Override
        protected boolean checkStats() {
            final long bytesSoFar = getTotalBytesForNetwork(mRequest.template);
            if (LOGV) {
                Log.v(TAG, bytesSoFar + " bytes so far since notification for "
                        + mRequest.template);
            }
            return bytesSoFar > mRequest.thresholdInBytes;
        }

        @Override
        protected void recordSample(StatsContext statsContext) {
            // Recorder does not need to be locked in this context since only the handler
            // thread will update it. Network-wide observers use the xt snapshot and its
            // interface map.
            mRecorder.recordSnapshotLocked(statsContext.mXtSnapshot, statsContext.mActiveIfaces,
                    statsContext.mCurrentTime);
        }

        /**
         * Reads stats matching the given template. {@link NetworkStatsCollection} will aggregate
         * over all buckets, which in this case should be only one since we built it big enough
         * that it will outlive the caller. If it doesn't, then there will be multiple buckets.
         */
        private long getTotalBytesForNetwork(NetworkTemplate template) {
            NetworkStats stats = mCollection.getSummary(template,
                    Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
                    mAccessLevel, mCallingUid);
            return stats.getTotalBytes();
        }
    }

    /** Observer that fires when any single visible uid crosses the threshold (USER and below). */
    private static class UserUsageRequestInfo extends RequestInfo {
        UserUsageRequestInfo(NetworkStatsObservers statsObserver, DataUsageRequest request,
                IUsageCallback callback, int callingPid, int callingUid,
                @NonNull String callingPackage, @NetworkStatsAccess.Level int accessLevel) {
            super(statsObserver, request, callback, callingPid, callingUid,
                    callingPackage, accessLevel);
        }

        @Override
        protected boolean checkStats() {
            final int[] uidsToMonitor = mCollection.getRelevantUids(mAccessLevel, mCallingUid);
            for (int uid : uidsToMonitor) {
                if (getTotalBytesForNetworkUid(mRequest.template, uid)
                        > mRequest.thresholdInBytes) {
                    return true;
                }
            }
            return false;
        }

        @Override
        protected void recordSample(StatsContext statsContext) {
            // Recorder does not need to be locked in this context since only the handler
            // thread will update it. Per-uid observers use the uid snapshot and its own
            // interface map.
            mRecorder.recordSnapshotLocked(statsContext.mUidSnapshot, statsContext.mActiveUidIfaces,
                    statsContext.mCurrentTime);
        }

        /**
         * Reads all stats matching the given template and uid. The history will likely only
         * contain one bucket per ident since we build it big enough that it will outlive the
         * caller lifetime.
         *
         * @return the total bytes, or 0 if the caller is no longer allowed to see {@code uid}.
         */
        private long getTotalBytesForNetworkUid(NetworkTemplate template, int uid) {
            try {
                NetworkStatsHistory history = mCollection.getHistory(template, null, uid,
                        NetworkStats.SET_ALL, NetworkStats.TAG_NONE,
                        NetworkStatsHistory.FIELD_ALL,
                        Long.MIN_VALUE /* start */, Long.MAX_VALUE /* end */,
                        mAccessLevel, mCallingUid);
                return history.getTotalBytes();
            } catch (SecurityException e) {
                if (LOGV) {
                    Log.w(TAG, "CallerUid " + mCallingUid + " may have lost access to uid "
                            + uid);
                }
                return 0;
            }
        }
    }

    /** Immutable bundle of one stats sample, passed from updateStats() to the handler thread. */
    private static class StatsContext {
        final NetworkStats mXtSnapshot;
        final NetworkStats mUidSnapshot;
        final ArrayMap<String, NetworkIdentitySet> mActiveIfaces;
        final ArrayMap<String, NetworkIdentitySet> mActiveUidIfaces;
        final long mCurrentTime;

        StatsContext(NetworkStats xtSnapshot, NetworkStats uidSnapshot,
                ArrayMap<String, NetworkIdentitySet> activeIfaces,
                ArrayMap<String, NetworkIdentitySet> activeUidIfaces,
                long currentTime) {
            mXtSnapshot = xtSnapshot;
            mUidSnapshot = uidSnapshot;
            mActiveIfaces = activeIfaces;
            mActiveUidIfaces = activeUidIfaces;
            mCurrentTime = currentTime;
        }
    }

    /**
     * Prints up to {@code DUMP_USAGE_REQUESTS_COUNT} registered requests.
     *
     * <p>NOTE(review): mDataUsageRequests is documented as handler-thread only. If dump() is
     * invoked from another thread, this read is unsynchronized; confirm the caller's thread.
     */
    public void dump(IndentingPrintWriter pw) {
        for (int i = 0; i < Math.min(mDataUsageRequests.size(), DUMP_USAGE_REQUESTS_COUNT); i++) {
            pw.println(mDataUsageRequests.valueAt(i));
        }
    }
}
