blob: a9c23110c944a1b63d9e128bfdca86cce56c53de [file] [log] [blame]
/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
17#define ANDROID_BASE_UNIQUE_FD_DISABLE_IMPLICIT_CONVERSION
18
19#include "include/adbd_auth.h"
20
21#include <inttypes.h>
22#include <sys/epoll.h>
23#include <sys/eventfd.h>
24#include <sys/uio.h>
25
26#include <chrono>
27#include <deque>
28#include <string>
29#include <string_view>
30#include <tuple>
31#include <unordered_map>
32#include <utility>
33#include <variant>
34#include <vector>
35
36#include <android-base/file.h>
37#include <android-base/logging.h>
38#include <android-base/macros.h>
39#include <android-base/strings.h>
40#include <android-base/thread_annotations.h>
41#include <android-base/unique_fd.h>
42#include <cutils/sockets.h>
43
44using android::base::unique_fd;
45
// Outgoing packet carrying a public key; AdbdAuthContext::SendPacket writes it
// to the framework with the two-byte header "CK".
struct AdbdAuthPacketAuthenticated {
    // Key bytes sent verbatim after the header.
    std::string public_key;
};
49
// Outgoing packet carrying a public key; AdbdAuthContext::SendPacket writes it
// to the framework with the two-byte header "DC".
struct AdbdAuthPacketDisconnected {
    // Key bytes sent verbatim after the header.
    std::string public_key;
};
53
// Outgoing packet carrying a public key; AdbdAuthContext::SendPacket writes it
// to the framework with the two-byte header "PK".
struct AdbdAuthPacketRequestAuthorization {
    // Key bytes sent verbatim after the header.
    std::string public_key;
};
57
58using AdbdAuthPacket = std::variant<AdbdAuthPacketAuthenticated, AdbdAuthPacketDisconnected,
59 AdbdAuthPacketRequestAuthorization>;
60
61struct AdbdAuthContext {
62 static constexpr uint64_t kEpollConstSocket = 0;
63 static constexpr uint64_t kEpollConstEventFd = 1;
64 static constexpr uint64_t kEpollConstFramework = 2;
65
66public:
67 explicit AdbdAuthContext(AdbdAuthCallbacksV1* callbacks) : next_id_(0), callbacks_(*callbacks) {
68 epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
69 if (epoll_fd_ == -1) {
70 PLOG(FATAL) << "failed to create epoll fd";
71 }
72
73 event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
74 if (event_fd_ == -1) {
75 PLOG(FATAL) << "failed to create eventfd";
76 }
77
78 sock_fd_.reset(android_get_control_socket("adbd"));
79 if (sock_fd_ == -1) {
80 PLOG(ERROR) << "failed to get adbd authentication socket";
81 } else {
82 if (fcntl(sock_fd_.get(), F_SETFD, FD_CLOEXEC) != 0) {
83 PLOG(FATAL) << "failed to make adbd authentication socket cloexec";
84 }
85
86 if (fcntl(sock_fd_.get(), F_SETFL, O_NONBLOCK) != 0) {
87 PLOG(FATAL) << "failed to make adbd authentication socket nonblocking";
88 }
89
90 if (listen(sock_fd_.get(), 4) != 0) {
91 PLOG(FATAL) << "failed to listen on adbd authentication socket";
92 }
93 }
94 }
95
96 AdbdAuthContext(const AdbdAuthContext& copy) = delete;
97 AdbdAuthContext(AdbdAuthContext&& move) = delete;
98 AdbdAuthContext& operator=(const AdbdAuthContext& copy) = delete;
99 AdbdAuthContext& operator=(AdbdAuthContext&& move) = delete;
100
101 uint64_t NextId() { return next_id_++; }
102
103 void DispatchPendingPrompt() REQUIRES(mutex_) {
104 if (dispatched_prompt_) {
105 LOG(INFO) << "adbd_auth: prompt currently pending, skipping";
106 return;
107 }
108
109 if (pending_prompts_.empty()) {
110 LOG(INFO) << "adbd_auth: no prompts to send";
111 return;
112 }
113
114 LOG(INFO) << "adbd_auth: prompting user for adb authentication";
115 auto [id, public_key, arg] = std::move(pending_prompts_.front());
116 pending_prompts_.pop_front();
117
118 this->output_queue_.emplace_back(
119 AdbdAuthPacketRequestAuthorization{.public_key = public_key});
120
121 Interrupt();
122 dispatched_prompt_ = std::make_tuple(id, public_key, arg);
123 }
124
125 void UpdateFrameworkWritable() REQUIRES(mutex_) {
126 // This might result in redundant calls to EPOLL_CTL_MOD if, for example, we get notified
127 // at the same time as a framework connection, but that's unlikely and this doesn't need to
128 // be fast anyway.
129 if (framework_fd_ != -1) {
130 struct epoll_event event;
131 event.events = EPOLLIN;
132 if (!output_queue_.empty()) {
133 LOG(INFO) << "marking framework writable";
134 event.events |= EPOLLOUT;
135 }
136 event.data.u64 = kEpollConstFramework;
137 CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_MOD, framework_fd_.get(), &event));
138 }
139 }
140
141 void ReplaceFrameworkFd(unique_fd new_fd) REQUIRES(mutex_) {
142 LOG(INFO) << "received new framework fd " << new_fd.get()
143 << " (current = " << framework_fd_.get() << ")";
144
145 // If we already had a framework fd, clean up after ourselves.
146 if (framework_fd_ != -1) {
147 output_queue_.clear();
148 dispatched_prompt_.reset();
149 CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_DEL, framework_fd_.get(), nullptr));
150 framework_fd_.reset();
151 }
152
153 if (new_fd != -1) {
154 struct epoll_event event;
155 event.events = EPOLLIN;
156 if (!output_queue_.empty()) {
157 LOG(INFO) << "marking framework writable";
158 event.events |= EPOLLOUT;
159 }
160 event.data.u64 = kEpollConstFramework;
161 CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, new_fd.get(), &event));
162 framework_fd_ = std::move(new_fd);
163 }
164 }
165
166 void HandlePacket(std::string_view packet) REQUIRES(mutex_) {
167 LOG(INFO) << "received packet: " << packet;
168
169 if (packet.length() < 2) {
170 LOG(ERROR) << "received packet of invalid length";
171 ReplaceFrameworkFd(unique_fd());
172 }
173
174 if (packet[0] == 'O' && packet[1] == 'K') {
175 CHECK(this->dispatched_prompt_.has_value());
176 auto& [id, key, arg] = *this->dispatched_prompt_;
177 keys_.emplace(id, std::move(key));
178
179 this->callbacks_.key_authorized(arg, id);
180 this->dispatched_prompt_ = std::nullopt;
Josh Gaobb927a12019-12-11 14:22:38 -0800181
182 // We need to dispatch pending prompts here upon success as well,
183 // since we might have multiple queued prompts.
184 DispatchPendingPrompt();
Josh Gaob789fb12019-10-24 19:02:14 -0700185 } else if (packet[0] == 'N' && packet[1] == 'O') {
186 CHECK_EQ(2UL, packet.length());
187 // TODO: Do we want a callback if the key is denied?
188 this->dispatched_prompt_ = std::nullopt;
189 DispatchPendingPrompt();
190 } else {
191 LOG(ERROR) << "unhandled packet: " << packet;
192 ReplaceFrameworkFd(unique_fd());
193 }
194 }
195
196 bool SendPacket() REQUIRES(mutex_) {
197 if (output_queue_.empty()) {
198 return false;
199 }
200
201 CHECK_NE(-1, framework_fd_.get());
202
203 auto& packet = output_queue_.front();
204 struct iovec iovs[2];
205 if (auto* p = std::get_if<AdbdAuthPacketAuthenticated>(&packet)) {
206 iovs[0].iov_base = const_cast<char*>("CK");
207 iovs[0].iov_len = 2;
208 iovs[1].iov_base = p->public_key.data();
209 iovs[1].iov_len = p->public_key.size();
210 } else if (auto* p = std::get_if<AdbdAuthPacketDisconnected>(&packet)) {
211 iovs[0].iov_base = const_cast<char*>("DC");
212 iovs[0].iov_len = 2;
213 iovs[1].iov_base = p->public_key.data();
214 iovs[1].iov_len = p->public_key.size();
215 } else if (auto* p = std::get_if<AdbdAuthPacketRequestAuthorization>(&packet)) {
216 iovs[0].iov_base = const_cast<char*>("PK");
217 iovs[0].iov_len = 2;
218 iovs[1].iov_base = p->public_key.data();
219 iovs[1].iov_len = p->public_key.size();
220 } else {
221 LOG(FATAL) << "unhandled packet type?";
222 }
223
224 output_queue_.pop_front();
225
226 ssize_t rc = writev(framework_fd_.get(), iovs, 2);
227 if (rc == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
228 PLOG(ERROR) << "failed to write to framework fd";
229 ReplaceFrameworkFd(unique_fd());
230 return false;
231 }
232
233 return true;
234 }
235
236 void Run() {
237 if (sock_fd_ == -1) {
238 LOG(ERROR) << "adbd authentication socket unavailable, disabling user prompts";
239 } else {
240 struct epoll_event event;
241 event.events = EPOLLIN;
242 event.data.u64 = kEpollConstSocket;
243 CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, sock_fd_.get(), &event));
244 }
245
246 {
247 struct epoll_event event;
248 event.events = EPOLLIN;
249 event.data.u64 = kEpollConstEventFd;
250 CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event));
251 }
252
253 while (true) {
254 struct epoll_event events[3];
255 int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 3, -1));
256 if (rc == -1) {
257 PLOG(FATAL) << "epoll_wait failed";
258 } else if (rc == 0) {
259 LOG(FATAL) << "epoll_wait returned 0";
260 }
261
262 bool restart = false;
263 for (int i = 0; i < rc; ++i) {
264 if (restart) {
265 break;
266 }
267
268 struct epoll_event& event = events[i];
269 switch (event.data.u64) {
270 case kEpollConstSocket: {
271 unique_fd new_framework_fd(accept4(sock_fd_.get(), nullptr, nullptr,
272 SOCK_CLOEXEC | SOCK_NONBLOCK));
273 if (new_framework_fd == -1) {
274 PLOG(FATAL) << "failed to accept framework fd";
275 }
276
277 LOG(INFO) << "adbd_auth: received a new framework connection";
278 std::lock_guard<std::mutex> lock(mutex_);
279 ReplaceFrameworkFd(std::move(new_framework_fd));
280
281 // Stop iterating over events: one of the later ones might be the old
282 // framework fd.
283 restart = false;
284 break;
285 }
286
287 case kEpollConstEventFd: {
288 // We were woken up to write something.
289 uint64_t dummy;
290 int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy)));
291 if (rc != 8) {
292 PLOG(FATAL) << "failed to read from eventfd (rc = " << rc << ")";
293 }
294
295 std::lock_guard<std::mutex> lock(mutex_);
296 UpdateFrameworkWritable();
297 break;
298 }
299
300 case kEpollConstFramework: {
301 char buf[4096];
302 if (event.events & EPOLLIN) {
303 int rc = TEMP_FAILURE_RETRY(read(framework_fd_.get(), buf, sizeof(buf)));
304 if (rc == -1) {
305 LOG(FATAL) << "failed to read from framework fd";
306 } else if (rc == 0) {
307 LOG(INFO) << "hit EOF on framework fd";
308 std::lock_guard<std::mutex> lock(mutex_);
309 ReplaceFrameworkFd(unique_fd());
310 } else {
311 std::lock_guard<std::mutex> lock(mutex_);
312 HandlePacket(std::string_view(buf, rc));
313 }
314 }
315
316 if (event.events & EPOLLOUT) {
317 std::lock_guard<std::mutex> lock(mutex_);
318 while (SendPacket()) {
319 continue;
320 }
321 UpdateFrameworkWritable();
322 }
323
324 break;
325 }
326 }
327 }
328 }
329 }
330
331 static constexpr const char* key_paths[] = {"/adb_keys", "/data/misc/adb/adb_keys"};
332 void IteratePublicKeys(bool (*callback)(const char*, size_t, void*), void* arg) {
333 for (const auto& path : key_paths) {
334 if (access(path, R_OK) == 0) {
335 LOG(INFO) << "Loading keys from " << path;
336 std::string content;
337 if (!android::base::ReadFileToString(path, &content)) {
338 PLOG(ERROR) << "Couldn't read " << path;
339 continue;
340 }
341 for (const auto& line : android::base::Split(content, "\n")) {
342 if (!callback(line.data(), line.size(), arg)) {
343 return;
344 }
345 }
346 }
347 }
348 }
349
350 uint64_t PromptUser(std::string_view public_key, void* arg) EXCLUDES(mutex_) {
351 uint64_t id = NextId();
352
353 std::lock_guard<std::mutex> lock(mutex_);
354 pending_prompts_.emplace_back(id, public_key, arg);
355 DispatchPendingPrompt();
356 return id;
357 }
358
359 uint64_t NotifyAuthenticated(std::string_view public_key) EXCLUDES(mutex_) {
360 uint64_t id = NextId();
361 std::lock_guard<std::mutex> lock(mutex_);
362 keys_.emplace(id, public_key);
363 output_queue_.emplace_back(
364 AdbdAuthPacketDisconnected{.public_key = std::string(public_key)});
365 return id;
366 }
367
368 void NotifyDisconnected(uint64_t id) EXCLUDES(mutex_) {
369 std::lock_guard<std::mutex> lock(mutex_);
370 auto it = keys_.find(id);
371 if (it == keys_.end()) {
372 LOG(DEBUG) << "couldn't find public key to notify disconnection, skipping";
373 return;
374 }
375 output_queue_.emplace_back(AdbdAuthPacketDisconnected{.public_key = std::move(it->second)});
376 keys_.erase(it);
377 }
378
379 // Interrupt the worker thread to do some work.
380 void Interrupt() {
381 uint64_t value = 1;
382 ssize_t rc = write(event_fd_.get(), &value, sizeof(value));
383 if (rc == -1) {
384 PLOG(FATAL) << "write to eventfd failed";
385 } else if (rc != sizeof(value)) {
386 LOG(FATAL) << "write to eventfd returned short (" << rc << ")";
387 }
388 }
389
390 unique_fd epoll_fd_;
391 unique_fd event_fd_;
392 unique_fd sock_fd_;
393 unique_fd framework_fd_;
394
395 std::atomic<uint64_t> next_id_;
396 AdbdAuthCallbacksV1 callbacks_;
397
398 std::mutex mutex_;
399 std::unordered_map<uint64_t, std::string> keys_ GUARDED_BY(mutex_);
400
401 // We keep two separate queues: one to handle backpressure from the socket (output_queue_)
402 // and one to make sure we only dispatch one authrequest at a time (pending_prompts_).
403 std::deque<AdbdAuthPacket> output_queue_;
404
405 std::optional<std::tuple<uint64_t, std::string, void*>> dispatched_prompt_ GUARDED_BY(mutex_);
406 std::deque<std::tuple<uint64_t, std::string, void*>> pending_prompts_ GUARDED_BY(mutex_);
407};
408
409AdbdAuthContext* adbd_auth_new(AdbdAuthCallbacks* callbacks) {
410 if (callbacks->version != 1) {
411 LOG(ERROR) << "received unknown AdbdAuthCallbacks version " << callbacks->version;
412 return nullptr;
413 }
414
415 return new AdbdAuthContext(&callbacks->callbacks.v1);
416}
417
418void adbd_auth_delete(AdbdAuthContext* ctx) {
419 delete ctx;
420}
421
422void adbd_auth_run(AdbdAuthContext* ctx) {
423 return ctx->Run();
424}
425
426void adbd_auth_get_public_keys(AdbdAuthContext* ctx,
427 bool (*callback)(const char* public_key, size_t len, void* arg),
428 void* arg) {
429 ctx->IteratePublicKeys(callback, arg);
430}
431
432uint64_t adbd_auth_notify_auth(AdbdAuthContext* ctx, const char* public_key, size_t len) {
433 return ctx->NotifyAuthenticated(std::string_view(public_key, len));
434}
435
436void adbd_auth_notify_disconnect(AdbdAuthContext* ctx, uint64_t id) {
437 return ctx->NotifyDisconnected(id);
438}
439
440void adbd_auth_prompt_user(AdbdAuthContext* ctx, const char* public_key, size_t len,
441 void* arg) {
442 ctx->PromptUser(std::string_view(public_key, len), arg);
443}
444
445bool adbd_auth_supports_feature(AdbdAuthFeature) {
446 return false;
447}