/*
 * Copyright (C) 2019 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#define ANDROID_BASE_UNIQUE_FD_DISABLE_IMPLICIT_CONVERSION

#include "include/adbd_auth.h"

#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <sys/uio.h>
#include <unistd.h>

#include <atomic>
#include <chrono>
#include <deque>
#include <mutex>
#include <optional>
#include <string>
#include <string_view>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include <android-base/file.h>
#include <android-base/logging.h>
#include <android-base/macros.h>
#include <android-base/strings.h>
#include <android-base/thread_annotations.h>
#include <android-base/unique_fd.h>
#include <cutils/sockets.h>

using android::base::unique_fd;

// Outgoing packet telling the framework that a key has authenticated.
// Sent on the wire with the two-byte tag "CK" followed by the key bytes.
struct AdbdAuthPacketAuthenticated {
    std::string public_key;
};

// Outgoing packet telling the framework that a previously reported key has
// disconnected. Sent on the wire with the two-byte tag "DC" followed by the
// key bytes.
struct AdbdAuthPacketDisconnected {
    std::string public_key;
};

// Outgoing packet asking the framework to prompt the user about a key.
// Sent on the wire with the two-byte tag "PK" followed by the key bytes.
struct AdbdAuthPacketRequestAuthorization {
    std::string public_key;
};

58using AdbdAuthPacket = std::variant<AdbdAuthPacketAuthenticated, AdbdAuthPacketDisconnected,
59 AdbdAuthPacketRequestAuthorization>;
60
61struct AdbdAuthContext {
62 static constexpr uint64_t kEpollConstSocket = 0;
63 static constexpr uint64_t kEpollConstEventFd = 1;
64 static constexpr uint64_t kEpollConstFramework = 2;
65
66public:
67 explicit AdbdAuthContext(AdbdAuthCallbacksV1* callbacks) : next_id_(0), callbacks_(*callbacks) {
68 epoll_fd_.reset(epoll_create1(EPOLL_CLOEXEC));
69 if (epoll_fd_ == -1) {
70 PLOG(FATAL) << "failed to create epoll fd";
71 }
72
73 event_fd_.reset(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK));
74 if (event_fd_ == -1) {
75 PLOG(FATAL) << "failed to create eventfd";
76 }
77
78 sock_fd_.reset(android_get_control_socket("adbd"));
79 if (sock_fd_ == -1) {
80 PLOG(ERROR) << "failed to get adbd authentication socket";
81 } else {
82 if (fcntl(sock_fd_.get(), F_SETFD, FD_CLOEXEC) != 0) {
83 PLOG(FATAL) << "failed to make adbd authentication socket cloexec";
84 }
85
86 if (fcntl(sock_fd_.get(), F_SETFL, O_NONBLOCK) != 0) {
87 PLOG(FATAL) << "failed to make adbd authentication socket nonblocking";
88 }
89
90 if (listen(sock_fd_.get(), 4) != 0) {
91 PLOG(FATAL) << "failed to listen on adbd authentication socket";
92 }
93 }
94 }
95
96 AdbdAuthContext(const AdbdAuthContext& copy) = delete;
97 AdbdAuthContext(AdbdAuthContext&& move) = delete;
98 AdbdAuthContext& operator=(const AdbdAuthContext& copy) = delete;
99 AdbdAuthContext& operator=(AdbdAuthContext&& move) = delete;
100
101 uint64_t NextId() { return next_id_++; }
102
103 void DispatchPendingPrompt() REQUIRES(mutex_) {
104 if (dispatched_prompt_) {
105 LOG(INFO) << "adbd_auth: prompt currently pending, skipping";
106 return;
107 }
108
109 if (pending_prompts_.empty()) {
110 LOG(INFO) << "adbd_auth: no prompts to send";
111 return;
112 }
113
114 LOG(INFO) << "adbd_auth: prompting user for adb authentication";
115 auto [id, public_key, arg] = std::move(pending_prompts_.front());
116 pending_prompts_.pop_front();
117
118 this->output_queue_.emplace_back(
119 AdbdAuthPacketRequestAuthorization{.public_key = public_key});
120
121 Interrupt();
122 dispatched_prompt_ = std::make_tuple(id, public_key, arg);
123 }
124
125 void UpdateFrameworkWritable() REQUIRES(mutex_) {
126 // This might result in redundant calls to EPOLL_CTL_MOD if, for example, we get notified
127 // at the same time as a framework connection, but that's unlikely and this doesn't need to
128 // be fast anyway.
129 if (framework_fd_ != -1) {
130 struct epoll_event event;
131 event.events = EPOLLIN;
132 if (!output_queue_.empty()) {
133 LOG(INFO) << "marking framework writable";
134 event.events |= EPOLLOUT;
135 }
136 event.data.u64 = kEpollConstFramework;
137 CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_MOD, framework_fd_.get(), &event));
138 }
139 }
140
141 void ReplaceFrameworkFd(unique_fd new_fd) REQUIRES(mutex_) {
142 LOG(INFO) << "received new framework fd " << new_fd.get()
143 << " (current = " << framework_fd_.get() << ")";
144
145 // If we already had a framework fd, clean up after ourselves.
146 if (framework_fd_ != -1) {
147 output_queue_.clear();
148 dispatched_prompt_.reset();
149 CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_DEL, framework_fd_.get(), nullptr));
150 framework_fd_.reset();
151 }
152
153 if (new_fd != -1) {
154 struct epoll_event event;
155 event.events = EPOLLIN;
156 if (!output_queue_.empty()) {
157 LOG(INFO) << "marking framework writable";
158 event.events |= EPOLLOUT;
159 }
160 event.data.u64 = kEpollConstFramework;
161 CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, new_fd.get(), &event));
162 framework_fd_ = std::move(new_fd);
163 }
164 }
165
166 void HandlePacket(std::string_view packet) REQUIRES(mutex_) {
167 LOG(INFO) << "received packet: " << packet;
168
169 if (packet.length() < 2) {
170 LOG(ERROR) << "received packet of invalid length";
171 ReplaceFrameworkFd(unique_fd());
172 }
173
174 if (packet[0] == 'O' && packet[1] == 'K') {
175 CHECK(this->dispatched_prompt_.has_value());
176 auto& [id, key, arg] = *this->dispatched_prompt_;
177 keys_.emplace(id, std::move(key));
178
179 this->callbacks_.key_authorized(arg, id);
180 this->dispatched_prompt_ = std::nullopt;
181 } else if (packet[0] == 'N' && packet[1] == 'O') {
182 CHECK_EQ(2UL, packet.length());
183 // TODO: Do we want a callback if the key is denied?
184 this->dispatched_prompt_ = std::nullopt;
185 DispatchPendingPrompt();
186 } else {
187 LOG(ERROR) << "unhandled packet: " << packet;
188 ReplaceFrameworkFd(unique_fd());
189 }
190 }
191
192 bool SendPacket() REQUIRES(mutex_) {
193 if (output_queue_.empty()) {
194 return false;
195 }
196
197 CHECK_NE(-1, framework_fd_.get());
198
199 auto& packet = output_queue_.front();
200 struct iovec iovs[2];
201 if (auto* p = std::get_if<AdbdAuthPacketAuthenticated>(&packet)) {
202 iovs[0].iov_base = const_cast<char*>("CK");
203 iovs[0].iov_len = 2;
204 iovs[1].iov_base = p->public_key.data();
205 iovs[1].iov_len = p->public_key.size();
206 } else if (auto* p = std::get_if<AdbdAuthPacketDisconnected>(&packet)) {
207 iovs[0].iov_base = const_cast<char*>("DC");
208 iovs[0].iov_len = 2;
209 iovs[1].iov_base = p->public_key.data();
210 iovs[1].iov_len = p->public_key.size();
211 } else if (auto* p = std::get_if<AdbdAuthPacketRequestAuthorization>(&packet)) {
212 iovs[0].iov_base = const_cast<char*>("PK");
213 iovs[0].iov_len = 2;
214 iovs[1].iov_base = p->public_key.data();
215 iovs[1].iov_len = p->public_key.size();
216 } else {
217 LOG(FATAL) << "unhandled packet type?";
218 }
219
220 output_queue_.pop_front();
221
222 ssize_t rc = writev(framework_fd_.get(), iovs, 2);
223 if (rc == -1 && errno != EAGAIN && errno != EWOULDBLOCK) {
224 PLOG(ERROR) << "failed to write to framework fd";
225 ReplaceFrameworkFd(unique_fd());
226 return false;
227 }
228
229 return true;
230 }
231
232 void Run() {
233 if (sock_fd_ == -1) {
234 LOG(ERROR) << "adbd authentication socket unavailable, disabling user prompts";
235 } else {
236 struct epoll_event event;
237 event.events = EPOLLIN;
238 event.data.u64 = kEpollConstSocket;
239 CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, sock_fd_.get(), &event));
240 }
241
242 {
243 struct epoll_event event;
244 event.events = EPOLLIN;
245 event.data.u64 = kEpollConstEventFd;
246 CHECK_EQ(0, epoll_ctl(epoll_fd_.get(), EPOLL_CTL_ADD, event_fd_.get(), &event));
247 }
248
249 while (true) {
250 struct epoll_event events[3];
251 int rc = TEMP_FAILURE_RETRY(epoll_wait(epoll_fd_.get(), events, 3, -1));
252 if (rc == -1) {
253 PLOG(FATAL) << "epoll_wait failed";
254 } else if (rc == 0) {
255 LOG(FATAL) << "epoll_wait returned 0";
256 }
257
258 bool restart = false;
259 for (int i = 0; i < rc; ++i) {
260 if (restart) {
261 break;
262 }
263
264 struct epoll_event& event = events[i];
265 switch (event.data.u64) {
266 case kEpollConstSocket: {
267 unique_fd new_framework_fd(accept4(sock_fd_.get(), nullptr, nullptr,
268 SOCK_CLOEXEC | SOCK_NONBLOCK));
269 if (new_framework_fd == -1) {
270 PLOG(FATAL) << "failed to accept framework fd";
271 }
272
273 LOG(INFO) << "adbd_auth: received a new framework connection";
274 std::lock_guard<std::mutex> lock(mutex_);
275 ReplaceFrameworkFd(std::move(new_framework_fd));
276
277 // Stop iterating over events: one of the later ones might be the old
278 // framework fd.
279 restart = false;
280 break;
281 }
282
283 case kEpollConstEventFd: {
284 // We were woken up to write something.
285 uint64_t dummy;
286 int rc = TEMP_FAILURE_RETRY(read(event_fd_.get(), &dummy, sizeof(dummy)));
287 if (rc != 8) {
288 PLOG(FATAL) << "failed to read from eventfd (rc = " << rc << ")";
289 }
290
291 std::lock_guard<std::mutex> lock(mutex_);
292 UpdateFrameworkWritable();
293 break;
294 }
295
296 case kEpollConstFramework: {
297 char buf[4096];
298 if (event.events & EPOLLIN) {
299 int rc = TEMP_FAILURE_RETRY(read(framework_fd_.get(), buf, sizeof(buf)));
300 if (rc == -1) {
301 LOG(FATAL) << "failed to read from framework fd";
302 } else if (rc == 0) {
303 LOG(INFO) << "hit EOF on framework fd";
304 std::lock_guard<std::mutex> lock(mutex_);
305 ReplaceFrameworkFd(unique_fd());
306 } else {
307 std::lock_guard<std::mutex> lock(mutex_);
308 HandlePacket(std::string_view(buf, rc));
309 }
310 }
311
312 if (event.events & EPOLLOUT) {
313 std::lock_guard<std::mutex> lock(mutex_);
314 while (SendPacket()) {
315 continue;
316 }
317 UpdateFrameworkWritable();
318 }
319
320 break;
321 }
322 }
323 }
324 }
325 }
326
327 static constexpr const char* key_paths[] = {"/adb_keys", "/data/misc/adb/adb_keys"};
328 void IteratePublicKeys(bool (*callback)(const char*, size_t, void*), void* arg) {
329 for (const auto& path : key_paths) {
330 if (access(path, R_OK) == 0) {
331 LOG(INFO) << "Loading keys from " << path;
332 std::string content;
333 if (!android::base::ReadFileToString(path, &content)) {
334 PLOG(ERROR) << "Couldn't read " << path;
335 continue;
336 }
337 for (const auto& line : android::base::Split(content, "\n")) {
338 if (!callback(line.data(), line.size(), arg)) {
339 return;
340 }
341 }
342 }
343 }
344 }
345
346 uint64_t PromptUser(std::string_view public_key, void* arg) EXCLUDES(mutex_) {
347 uint64_t id = NextId();
348
349 std::lock_guard<std::mutex> lock(mutex_);
350 pending_prompts_.emplace_back(id, public_key, arg);
351 DispatchPendingPrompt();
352 return id;
353 }
354
355 uint64_t NotifyAuthenticated(std::string_view public_key) EXCLUDES(mutex_) {
356 uint64_t id = NextId();
357 std::lock_guard<std::mutex> lock(mutex_);
358 keys_.emplace(id, public_key);
359 output_queue_.emplace_back(
360 AdbdAuthPacketDisconnected{.public_key = std::string(public_key)});
361 return id;
362 }
363
364 void NotifyDisconnected(uint64_t id) EXCLUDES(mutex_) {
365 std::lock_guard<std::mutex> lock(mutex_);
366 auto it = keys_.find(id);
367 if (it == keys_.end()) {
368 LOG(DEBUG) << "couldn't find public key to notify disconnection, skipping";
369 return;
370 }
371 output_queue_.emplace_back(AdbdAuthPacketDisconnected{.public_key = std::move(it->second)});
372 keys_.erase(it);
373 }
374
375 // Interrupt the worker thread to do some work.
376 void Interrupt() {
377 uint64_t value = 1;
378 ssize_t rc = write(event_fd_.get(), &value, sizeof(value));
379 if (rc == -1) {
380 PLOG(FATAL) << "write to eventfd failed";
381 } else if (rc != sizeof(value)) {
382 LOG(FATAL) << "write to eventfd returned short (" << rc << ")";
383 }
384 }
385
386 unique_fd epoll_fd_;
387 unique_fd event_fd_;
388 unique_fd sock_fd_;
389 unique_fd framework_fd_;
390
391 std::atomic<uint64_t> next_id_;
392 AdbdAuthCallbacksV1 callbacks_;
393
394 std::mutex mutex_;
395 std::unordered_map<uint64_t, std::string> keys_ GUARDED_BY(mutex_);
396
397 // We keep two separate queues: one to handle backpressure from the socket (output_queue_)
398 // and one to make sure we only dispatch one authrequest at a time (pending_prompts_).
399 std::deque<AdbdAuthPacket> output_queue_;
400
401 std::optional<std::tuple<uint64_t, std::string, void*>> dispatched_prompt_ GUARDED_BY(mutex_);
402 std::deque<std::tuple<uint64_t, std::string, void*>> pending_prompts_ GUARDED_BY(mutex_);
403};
404
405AdbdAuthContext* adbd_auth_new(AdbdAuthCallbacks* callbacks) {
406 if (callbacks->version != 1) {
407 LOG(ERROR) << "received unknown AdbdAuthCallbacks version " << callbacks->version;
408 return nullptr;
409 }
410
411 return new AdbdAuthContext(&callbacks->callbacks.v1);
412}
413
414void adbd_auth_delete(AdbdAuthContext* ctx) {
415 delete ctx;
416}
417
418void adbd_auth_run(AdbdAuthContext* ctx) {
419 return ctx->Run();
420}
421
422void adbd_auth_get_public_keys(AdbdAuthContext* ctx,
423 bool (*callback)(const char* public_key, size_t len, void* arg),
424 void* arg) {
425 ctx->IteratePublicKeys(callback, arg);
426}
427
428uint64_t adbd_auth_notify_auth(AdbdAuthContext* ctx, const char* public_key, size_t len) {
429 return ctx->NotifyAuthenticated(std::string_view(public_key, len));
430}
431
432void adbd_auth_notify_disconnect(AdbdAuthContext* ctx, uint64_t id) {
433 return ctx->NotifyDisconnected(id);
434}
435
436void adbd_auth_prompt_user(AdbdAuthContext* ctx, const char* public_key, size_t len,
437 void* arg) {
438 ctx->PromptUser(std::string_view(public_key, len), arg);
439}
440
441bool adbd_auth_supports_feature(AdbdAuthFeature) {
442 return false;
443}