blob: 63a28a4e1159f4c9c08a7fcd2d01417ec368f2f1 [file] [log] [blame]
David Drysdale79af2662024-02-19 14:50:31 +00001// Copyright 2024, The Android Open Source Project
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Database of VM IDs.
16
17use anyhow::{Context, Result};
18use log::{debug, error, info, warn};
19use rusqlite::{params, params_from_iter, Connection, OpenFlags, Rows};
20use std::path::PathBuf;
21
/// Subdirectory, created under the directory passed to [`VmIdDb::new`], that holds the database.
const DB_DIR: &str = "vmdb";

/// Name of the file, inside [`DB_DIR`], that holds the database.
const DB_FILENAME: &str = "vmids.sqlite";

/// Maximum number of host parameters in a single SQL statement.
/// (Default value of `SQLITE_LIMIT_VARIABLE_NUMBER` for SQLite <= 3.32.0, used here as a
/// conservative bound on how many `?` placeholders a single statement may contain.)
const MAX_VARIABLES: usize = 999;

/// Identifier for a VM and its corresponding secret, stored as a 64-byte blob.
pub type VmId = [u8; 64];
34
/// Representation of an on-disk database of VM IDs.
///
/// Each row associates a [`VmId`] with the Android user ID and app ID that own it.
pub struct VmIdDb {
    /// Open SQLite connection to the database (an in-memory one when built via `new_test_db`).
    conn: Connection,
}
39
40impl VmIdDb {
41 /// Connect to the VM ID database file held in the given directory, creating it if necessary.
42 /// The second return value indicates whether a new database file was created.
43 ///
44 /// This function assumes no other threads/processes are attempting to connect concurrently.
45 pub fn new(db_dir: &str) -> Result<(Self, bool)> {
46 let mut db_path = PathBuf::from(db_dir);
47 db_path.push(DB_DIR);
48 if !db_path.exists() {
49 std::fs::create_dir(&db_path).context("failed to create {db_path:?}")?;
50 info!("created persistent db dir {db_path:?}");
51 }
52
53 db_path.push(DB_FILENAME);
54 let (flags, created) = if db_path.exists() {
55 debug!("connecting to existing database {db_path:?}");
56 (
57 OpenFlags::SQLITE_OPEN_READ_WRITE
58 | OpenFlags::SQLITE_OPEN_URI
59 | OpenFlags::SQLITE_OPEN_NO_MUTEX,
60 false,
61 )
62 } else {
63 info!("creating fresh database {db_path:?}");
64 (
65 OpenFlags::SQLITE_OPEN_READ_WRITE
66 | OpenFlags::SQLITE_OPEN_CREATE
67 | OpenFlags::SQLITE_OPEN_URI
68 | OpenFlags::SQLITE_OPEN_NO_MUTEX,
69 true,
70 )
71 };
72 let mut result = Self {
73 conn: Connection::open_with_flags(db_path, flags)
74 .context(format!("failed to open/create DB with {flags:?}"))?,
75 };
76
77 if created {
78 result.init_tables().context("failed to create tables")?;
79 }
80 Ok((result, created))
81 }
82
83 /// Delete the associated database file.
84 pub fn delete_db_file(self, db_dir: &str) {
85 let mut db_path = PathBuf::from(db_dir);
86 db_path.push(DB_DIR);
87 db_path.push(DB_FILENAME);
88
89 // Drop the connection before removing the backing file.
90 drop(self);
91 warn!("removing database file {db_path:?}");
92 if let Err(e) = std::fs::remove_file(&db_path) {
93 error!("failed to remove database file {db_path:?}: {e:?}");
94 }
95 }
96
97 /// Create the database table and indices.
98 fn init_tables(&mut self) -> Result<()> {
99 self.conn
100 .execute(
101 "CREATE TABLE IF NOT EXISTS main.vmids (
102 vm_id BLOB PRIMARY KEY,
103 user_id INTEGER,
104 app_id INTEGER
105 ) WITHOUT ROWID;",
106 (),
107 )
108 .context("failed to create table")?;
109 self.conn
110 .execute("CREATE INDEX IF NOT EXISTS main.vmids_user_index ON vmids(user_id);", [])
111 .context("Failed to create user index")?;
112 self.conn
113 .execute(
114 "CREATE INDEX IF NOT EXISTS main.vmids_app_index ON vmids(user_id, app_id);",
115 [],
116 )
117 .context("Failed to create app index")?;
118 Ok(())
119 }
120
121 /// Add the given VM ID into the database.
David Drysdale79af2662024-02-19 14:50:31 +0000122 pub fn add_vm_id(&mut self, vm_id: &VmId, user_id: i32, app_id: i32) -> Result<()> {
123 let _rows = self
124 .conn
125 .execute(
126 "REPLACE INTO main.vmids (vm_id, user_id, app_id) VALUES (?1, ?2, ?3);",
127 params![vm_id, &user_id, &app_id],
128 )
129 .context("failed to add VM ID")?;
130 Ok(())
131 }
132
133 /// Remove the given VM IDs from the database. The collection of IDs is assumed to be smaller
134 /// than the maximum number of SQLite parameters.
135 pub fn delete_vm_ids(&mut self, vm_ids: &[VmId]) -> Result<()> {
136 assert!(vm_ids.len() < MAX_VARIABLES);
137 let mut vars = "?,".repeat(vm_ids.len());
138 vars.pop(); // remove trailing comma
139 let sql = format!("DELETE FROM main.vmids WHERE vm_id IN ({});", vars);
140 let mut stmt = self.conn.prepare(&sql).context("failed to prepare DELETE stmt")?;
141 let _rows = stmt.execute(params_from_iter(vm_ids)).context("failed to delete VM IDs")?;
142 Ok(())
143 }
144
145 /// Return the VM IDs associated with Android user ID `user_id`.
146 pub fn vm_ids_for_user(&mut self, user_id: i32) -> Result<Vec<VmId>> {
147 let mut stmt = self
148 .conn
149 .prepare("SELECT vm_id FROM main.vmids WHERE user_id = ?;")
150 .context("failed to prepare SELECT stmt")?;
151 let rows = stmt.query(params![user_id]).context("query failed")?;
152 Self::vm_ids_from_rows(rows)
153 }
154
155 /// Return the VM IDs associated with `(user_id, app_id)`.
156 pub fn vm_ids_for_app(&mut self, user_id: i32, app_id: i32) -> Result<Vec<VmId>> {
157 let mut stmt = self
158 .conn
159 .prepare("SELECT vm_id FROM main.vmids WHERE user_id = ? AND app_id = ?;")
160 .context("failed to prepare SELECT stmt")?;
161 let rows = stmt.query(params![user_id, app_id]).context("query failed")?;
162 Self::vm_ids_from_rows(rows)
163 }
164
165 /// Retrieve a collection of VM IDs from database rows.
166 fn vm_ids_from_rows(mut rows: Rows) -> Result<Vec<VmId>> {
167 let mut vm_ids: Vec<VmId> = Vec::new();
168 while let Some(row) = rows.next().context("failed row unpack")? {
169 match row.get(0) {
170 Ok(vm_id) => vm_ids.push(vm_id),
171 Err(e) => log::error!("failed to parse row: {e:?}"),
172 }
173 }
174
175 Ok(vm_ids)
176 }
177}
178
/// Build a [`VmIdDb`] backed by a fresh in-memory SQLite database, with its tables created.
///
/// Panics if the in-memory database cannot be opened or its tables cannot be initialized.
#[cfg(test)]
pub fn new_test_db() -> VmIdDb {
    let conn = Connection::open_in_memory().unwrap();
    let mut db = VmIdDb { conn };
    db.init_tables().unwrap();
    db
}
185
#[cfg(test)]
mod tests {
    use super::*;

    const VM_ID1: VmId = [1u8; 64];
    const VM_ID2: VmId = [2u8; 64];
    const VM_ID3: VmId = [3u8; 64];
    const VM_ID4: VmId = [4u8; 64];
    const VM_ID5: VmId = [5u8; 64];

    const USER1: i32 = 1;
    const USER2: i32 = 2;
    const USER3: i32 = 3;
    const USER_UNKNOWN: i32 = 4;

    const APP_A: i32 = 50;
    const APP_B: i32 = 60;
    const APP_C: i32 = 70;
    const APP_UNKNOWN: i32 = 99;

    /// Insert every `(vm_id, user_id, app_id)` triple into `db`, in the order given.
    fn add_all(db: &mut VmIdDb, entries: &[(VmId, i32, i32)]) {
        for (vm_id, user_id, app_id) in entries {
            db.add_vm_id(vm_id, *user_id, *app_id).unwrap();
        }
    }

    /// Check the lookups expected once all five VM IDs are registered.
    fn assert_fully_populated(db: &mut VmIdDb) {
        let nothing = Vec::<VmId>::new();
        assert_eq!(vec![VM_ID1, VM_ID2, VM_ID3], db.vm_ids_for_user(USER1).unwrap());
        assert_eq!(vec![VM_ID1, VM_ID2, VM_ID3], db.vm_ids_for_app(USER1, APP_A).unwrap());
        assert_eq!(vec![VM_ID4], db.vm_ids_for_app(USER2, APP_B).unwrap());
        assert_eq!(vec![VM_ID5], db.vm_ids_for_user(USER3).unwrap());
        assert_eq!(nothing, db.vm_ids_for_user(USER_UNKNOWN).unwrap());
        assert_eq!(nothing, db.vm_ids_for_app(USER1, APP_UNKNOWN).unwrap());
    }

    /// Check that `USER1` (and `USER1`/`APP_A`) is left owning only `VM_ID1`.
    fn assert_only_first_remains(db: &mut VmIdDb) {
        assert_eq!(vec![VM_ID1], db.vm_ids_for_user(USER1).unwrap());
        assert_eq!(vec![VM_ID1], db.vm_ids_for_app(USER1, APP_A).unwrap());
    }

    #[test]
    fn test_add_remove() {
        let mut db = new_test_db();
        add_all(
            &mut db,
            &[
                (VM_ID1, USER1, APP_A),
                (VM_ID2, USER1, APP_A),
                (VM_ID3, USER1, APP_A),
                (VM_ID4, USER2, APP_B),
                (VM_ID5, USER3, APP_A),
                // Re-adding an existing VM ID overwrites its owner.
                (VM_ID5, USER3, APP_C),
            ],
        );
        assert_fully_populated(&mut db);

        db.delete_vm_ids(&[VM_ID2, VM_ID3]).unwrap();
        assert_only_first_remains(&mut db);

        // OK to delete things that don't exist.
        db.delete_vm_ids(&[VM_ID2, VM_ID3]).unwrap();
        assert_only_first_remains(&mut db);

        add_all(&mut db, &[(VM_ID2, USER1, APP_A), (VM_ID3, USER1, APP_A)]);
        assert_fully_populated(&mut db);
    }

    #[test]
    fn test_invalid_vm_id() {
        let mut db = new_test_db();
        // Insert in descending ID order to show the query results are reordered.
        add_all(
            &mut db,
            &[(VM_ID3, USER1, APP_A), (VM_ID2, USER1, APP_A), (VM_ID1, USER1, APP_A)],
        );

        // Note that results are returned in `vm_id` order, because the table is `WITHOUT ROWID`.
        let expected = vec![VM_ID1, VM_ID2, VM_ID3];
        assert_eq!(expected, db.vm_ids_for_user(USER1).unwrap());

        // Manually insert a row with a VM ID that's the wrong size.
        db.conn
            .execute(
                "REPLACE INTO main.vmids (vm_id, user_id, app_id) VALUES (?1, ?2, ?3);",
                params![&[99u8; 60], &USER1, APP_A],
            )
            .unwrap();

        // Invalid row is skipped and remainder returned.
        assert_eq!(expected, db.vm_ids_for_user(USER1).unwrap());
    }
}