diff --git a/kinode/src/sqlite.rs b/kinode/src/sqlite.rs index 50dfaf85a..5cda7a0d9 100644 --- a/kinode/src/sqlite.rs +++ b/kinode/src/sqlite.rs @@ -4,8 +4,9 @@ use dashmap::DashMap; use lib::types::core::{ Address, CapMessage, CapMessageSender, Capability, FdManagerRequest, KernelMessage, LazyLoadBlob, Message, MessageReceiver, MessageSender, PackageId, PrintSender, Printout, - ProcessId, Request, Response, SqlValue, SqliteAction, SqliteError, SqliteRequest, - SqliteResponse, FD_MANAGER_PROCESS_ID, SQLITE_PROCESS_ID, + ProcessId, Request, Response, SqlValue, SqliteAction, SqliteCapabilityKind, + SqliteCapabilityParams, SqliteError, SqliteRequest, SqliteResponse, FD_MANAGER_PROCESS_ID, + SQLITE_PROCESS_ID, }; use rusqlite::Connection; use std::{ @@ -54,51 +55,46 @@ impl SqliteState { } } - pub async fn open_db(&mut self, package_id: PackageId, db: String) -> Result<(), SqliteError> { - let key = (package_id.clone(), db.clone()); - if self.open_dbs.contains_key(&key) { + pub async fn open_db(&mut self, key: &(PackageId, String)) -> Result<(), SqliteError> { + if self.open_dbs.contains_key(key) { let mut access_order = self.access_order.lock().await; - access_order.remove(&key); - access_order.push_back(key); + access_order.remove(key); + access_order.push_back(key.clone()); return Ok(()); } if self.open_dbs.len() as u64 >= self.fds_limit { // close least recently used db - let key = self.access_order.lock().await.pop_front().unwrap(); - self.remove_db(key.0, key.1).await; + let to_close = self.access_order.lock().await.pop_front().unwrap(); + self.remove_db(&to_close).await; } #[cfg(unix)] - let db_path = self.sqlite_path.join(format!("{package_id}")).join(&db); + let db_path = self.sqlite_path.join(format!("{}", key.0)).join(&key.1); #[cfg(target_os = "windows")] let db_path = self .sqlite_path - .join(format!( - "{}_{}", - package_id._package(), - package_id._publisher() - )) - .join(&db); + .join(format!("{}_{}", key.0._package(), key.0._publisher())) + .join(&key.1); fs::create_dir_all(&db_path).await?; - let db_file_path = db_path.join(format!("{}.db", db)); + let db_file_path = db_path.join(format!("{}.db", key.1)); let db_conn = Connection::open(db_file_path)?; let _: String = db_conn.query_row("PRAGMA journal_mode=WAL", [], |row| row.get(0))?; - self.open_dbs.insert(key, Mutex::new(db_conn)); + self.open_dbs.insert(key.clone(), Mutex::new(db_conn)); let mut access_order = self.access_order.lock().await; - access_order.push_back((package_id, db)); + access_order.push_back(key.clone()); Ok(()) } - pub async fn remove_db(&mut self, package_id: PackageId, db: String) { - self.open_dbs.remove(&(package_id.clone(), db.to_string())); + pub async fn remove_db(&mut self, key: &(PackageId, String)) { + self.open_dbs.remove(key); let mut access_order = self.access_order.lock().await; - access_order.remove(&(package_id, db)); + access_order.remove(key); } pub async fn remove_least_recently_used_dbs(&mut self, n: u64) { @@ -106,7 +102,7 @@ impl SqliteState { let mut lock = self.access_order.lock().await; let key = lock.pop_front().unwrap(); drop(lock); - self.remove_db(key.0, key.1).await; + self.remove_db(&key).await; } } } @@ -176,8 +172,7 @@ pub async fn sqlite( tokio::spawn(async move { let mut queue_lock = queue.lock().await; if let Some(km) = queue_lock.pop_front() { - let (km_id, km_rsvp) = - (km.id.clone(), km.rsvp.clone().unwrap_or(km.source.clone())); + let (km_id, km_rsvp) = (km.id, km.rsvp.clone().unwrap_or(km.source.clone())); if let Err(e) = handle_request(km, &mut state, &send_to_caps_oracle).await { Printout::new(1, SQLITE_PROCESS_ID.clone(), format!("sqlite: {e}")) @@ -226,27 +221,31 @@ async fn handle_request( .. }) = message else { - return Err(SqliteError::InputError { - error: "not a request".into(), - }); + // we got a response -- safe to ignore + return Ok(()); }; let request: SqliteRequest = match serde_json::from_slice(&body) { Ok(r) => r, Err(e) => { - println!("sqlite: got invalid Request: {}", e); - return Err(SqliteError::InputError { - error: "didn't serialize to SqliteRequest.".into(), - }); + println!("sqlite: got invalid request: {e}"); + return Err(SqliteError::MalformedRequest); } }; - check_caps(&source, state, send_to_caps_oracle, &request).await?; + let db_key = (request.package_id, request.db); + + check_caps( + &source, + state, + send_to_caps_oracle, + &request.action, + &db_key, + ) + .await?; // always open to ensure db exists - state - .open_db(request.package_id.clone(), request.db.clone()) - .await?; + state.open_db(&db_key).await?; let (body, bytes) = match request.action { SqliteAction::Open => { @@ -257,11 +256,11 @@ async fn handle_request( // handled in check_caps (serde_json::to_vec(&SqliteResponse::Ok).unwrap(), None) } - SqliteAction::Read { query } => { - let db = match state.open_dbs.get(&(request.package_id, request.db)) { + SqliteAction::Query(query) => { + let db = match state.open_dbs.get(&db_key) { Some(db) => db, None => { - return Err(SqliteError::NoDb); + return Err(SqliteError::NoDb(db_key.0, db_key.1)); } }; let db = db.lock().await; @@ -314,10 +313,10 @@ async fn handle_request( ) } SqliteAction::Write { statement, tx_id } => { - let db = match state.open_dbs.get(&(request.package_id, request.db)) { + let db = match state.open_dbs.get(&db_key) { Some(db) => db, None => { - return Err(SqliteError::NoDb); + return Err(SqliteError::NoDb(db_key.0, db_key.1)); } }; let db = db.lock().await; @@ -359,17 +358,17 @@ async fn handle_request( ) } SqliteAction::Commit { tx_id } => { - let db = match state.open_dbs.get(&(request.package_id, request.db)) { + let db = match state.open_dbs.get(&db_key) { Some(db) => db, None => { - return Err(SqliteError::NoDb); + return Err(SqliteError::NoDb(db_key.0, db_key.1)); } }; let mut db = db.lock().await; let txs = match state.txs.remove(&tx_id).map(|(_, tx)| tx) { None => { - return Err(SqliteError::NoTx); + return Err(SqliteError::NoTx(tx_id)); } Some(tx) => tx, }; @@ -382,20 +381,6 @@ async fn handle_request( tx.commit()?; (serde_json::to_vec(&SqliteResponse::Ok).unwrap(), None) } - SqliteAction::Backup => { - for db_ref in state.open_dbs.iter() { - let db = db_ref.value().lock().await; - let result: rusqlite::Result<()> = db - .query_row("PRAGMA wal_checkpoint(TRUNCATE)", [], |_| Ok(())) - .map(|_| ()); - if let Err(e) = result { - return Err(SqliteError::RusqliteError { - error: e.to_string(), - }); - } - } - (serde_json::to_vec(&SqliteResponse::Ok).unwrap(), None) - } }; if let Some(target) = km.rsvp.or_else(|| expects_response.map(|_| source)) { @@ -429,128 +414,110 @@ async fn check_caps( source: &Address, state: &mut SqliteState, send_to_caps_oracle: &CapMessageSender, - request: &SqliteRequest, + action: &SqliteAction, + db_key: &(PackageId, String), ) -> Result<(), SqliteError> { let (send_cap_bool, recv_cap_bool) = tokio::sync::oneshot::channel(); let src_package_id = PackageId::new(source.process.package(), source.process.publisher()); - match &request.action { + match action { SqliteAction::Write { .. } | SqliteAction::BeginTx | SqliteAction::Commit { .. } => { - send_to_caps_oracle + let Ok(()) = send_to_caps_oracle .send(CapMessage::Has { on: source.process.clone(), cap: Capability::new( state.our.as_ref().clone(), - serde_json::json!({ - "kind": "write", - "db": request.db.to_string(), + serde_json::to_string(&SqliteCapabilityParams { + kind: SqliteCapabilityKind::Write, + db_key: db_key.clone(), }) - .to_string(), + .unwrap(), ), responder: send_cap_bool, }) - .await?; - let has_cap = recv_cap_bool.await?; - if !has_cap { - return Err(SqliteError::NoCap { - error: request.action.to_string(), - }); - } + .await + else { + return Err(SqliteError::AddCapFailed); + }; + let Ok(_) = recv_cap_bool.await else { + return Err(SqliteError::AddCapFailed); + }; Ok(()) } - SqliteAction::Read { .. } => { - send_to_caps_oracle + SqliteAction::Query { .. } => { + let Ok(()) = send_to_caps_oracle .send(CapMessage::Has { on: source.process.clone(), cap: Capability::new( state.our.as_ref().clone(), - serde_json::json!({ - "kind": "read", - "db": request.db.to_string(), + serde_json::to_string(&SqliteCapabilityParams { + kind: SqliteCapabilityKind::Read, + db_key: db_key.clone(), }) - .to_string(), + .unwrap(), ), responder: send_cap_bool, }) - .await?; - let has_cap = recv_cap_bool.await?; - if !has_cap { - return Err(SqliteError::NoCap { - error: request.action.to_string(), - }); - } + .await + else { + return Err(SqliteError::AddCapFailed); + }; + let Ok(_) = recv_cap_bool.await else { + return Err(SqliteError::AddCapFailed); + }; Ok(()) } SqliteAction::Open => { - if src_package_id != request.package_id { - return Err(SqliteError::NoCap { - error: request.action.to_string(), - }); + if src_package_id != db_key.0 { + return Err(SqliteError::MismatchingPackageId); } add_capability( - "read", - &request.db.to_string(), + SqliteCapabilityKind::Read, + db_key, &state.our, &source, send_to_caps_oracle, ) .await?; add_capability( - "write", - &request.db.to_string(), + SqliteCapabilityKind::Write, + db_key, &state.our, &source, send_to_caps_oracle, ) .await?; - if state - .open_dbs - .contains_key(&(request.package_id.clone(), request.db.clone())) - { + if state.open_dbs.contains_key(db_key) { return Ok(()); } - state - .open_db(request.package_id.clone(), request.db.clone()) - .await?; + state.open_db(db_key).await?; Ok(()) } SqliteAction::RemoveDb => { - if src_package_id != request.package_id { - return Err(SqliteError::NoCap { - error: request.action.to_string(), - }); + if src_package_id != db_key.0 { + return Err(SqliteError::MismatchingPackageId); } - state - .remove_db(request.package_id.clone(), request.db.clone()) - .await; + state.remove_db(db_key).await; #[cfg(unix)] let db_path = state .sqlite_path - .join(format!("{}", request.package_id)) - .join(&request.db); + .join(format!("{}", db_key.0)) + .join(&db_key.1); #[cfg(target_os = "windows")] let db_path = state .sqlite_path - .join(format!( - "{}_{}", - request.package_id._package(), - request.package_id._publisher() - )) - .join(&request.db); + .join(format!("{}_{}", db_key.0._package(), db_key.0._publisher())) + .join(&db_key.1); fs::remove_dir_all(&db_path).await?; Ok(()) } - SqliteAction::Backup => { - // flushing WALs for backup - Ok(()) - } } } @@ -559,9 +526,7 @@ async fn handle_fd_request(km: KernelMessage, state: &mut SqliteState) -> anyhow return Err(anyhow::anyhow!("not a request")); }; - let request: FdManagerRequest = serde_json::from_slice(&body)?; - - match request { + match serde_json::from_slice(&body)? { FdManagerRequest::FdsLimit(new_fds_limit) => { state.fds_limit = new_fds_limit; if state.open_dbs.len() as u64 >= state.fds_limit { @@ -581,25 +546,34 @@ async fn handle_fd_request(km: KernelMessage, state: &mut SqliteState) -> anyhow } async fn add_capability( - kind: &str, - db: &str, + kind: SqliteCapabilityKind, + db_key: &(PackageId, String), our: &Address, source: &Address, send_to_caps_oracle: &CapMessageSender, ) -> Result<(), SqliteError> { let cap = Capability { issuer: our.clone(), - params: serde_json::json!({ "kind": kind, "db": db }).to_string(), + params: serde_json::to_string(&SqliteCapabilityParams { + kind, + db_key: db_key.clone(), + }) + .unwrap(), }; let (send_cap_bool, recv_cap_bool) = tokio::sync::oneshot::channel(); - send_to_caps_oracle + let Ok(()) = send_to_caps_oracle .send(CapMessage::Add { on: source.process.clone(), caps: vec![cap], responder: Some(send_cap_bool), }) - .await?; - let _ = recv_cap_bool.await?; + .await + else { + return Err(SqliteError::AddCapFailed); + }; + let Ok(_) = recv_cap_bool.await else { + return Err(SqliteError::AddCapFailed); + }; Ok(()) } diff --git a/lib/src/sqlite.rs b/lib/src/sqlite.rs index 1e53ec5c4..90fc009ae 100644 --- a/lib/src/sqlite.rs +++ b/lib/src/sqlite.rs @@ -1,43 +1,117 @@ -use crate::types::core::{CapMessage, PackageId}; +use crate::types::core::PackageId; use rusqlite::types::{FromSql, FromSqlError, ToSql, ValueRef}; use serde::{Deserialize, Serialize}; use thiserror::Error; -/// IPC Request format for the sqlite:distro:sys runtime module. -#[derive(Debug, Serialize, Deserialize)] +/// Actions are sent to a specific SQLite database. `db` is the name, +/// `package_id` is the [`PackageId`] that created the database. Capabilities +/// are checked: you can access another process's database if it has given +/// you the read and/or write capability to do so. +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct SqliteRequest { pub package_id: PackageId, pub db: String, pub action: SqliteAction, } -#[derive(Debug, Serialize, Deserialize)] +/// IPC Action format representing operations that can be performed on the +/// SQLite runtime module. These actions are included in a [`SqliteRequest`] +/// sent to the `sqlite:distro:sys` runtime module. +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum SqliteAction { + /// Opens an existing key-value database or creates a new one if it doesn't exist. + /// Requires `package_id` in [`SqliteRequest`] to match the package ID of the sender. + /// The sender will own the database and can remove it with [`SqliteAction::RemoveDb`]. + /// + /// A successful open will respond with [`SqliteResponse::Ok`]. Any error will be + /// contained in the [`SqliteResponse::Err`] variant. Open, + /// Permanently deletes the entire key-value database. + /// Requires `package_id` in [`SqliteRequest`] to match the package ID of the sender. + /// Only the owner can remove the database. + /// + /// A successful remove will respond with [`SqliteResponse::Ok`]. Any error will be + /// contained in the [`SqliteResponse::Err`] variant. RemoveDb, + /// Executes a write statement (INSERT/UPDATE/DELETE) + /// + /// * `statement` - SQL statement to execute + /// * `tx_id` - Optional transaction ID + /// * blob: Vec - Parameters for the SQL statement, where SqlValue can be: + /// - null + /// - boolean + /// - i64 + /// - f64 + /// - String + /// - Vec (binary data) + /// + /// Using this action requires the sender to have the write capability + /// for the database. + /// + /// A successful write will respond with [`SqliteResponse::Ok`]. Any error will be + /// contained in the [`SqliteResponse::Err`] variant. Write { statement: String, tx_id: Option, }, - Read { - query: String, - }, + /// Executes a read query (SELECT) + /// + /// * blob: Vec - Parameters for the SQL query, where SqlValue can be: + /// - null + /// - boolean + /// - i64 + /// - f64 + /// - String + /// - Vec (binary data) + /// + /// Using this action requires the sender to have the read capability + /// for the database. + /// + /// A successful query will respond with [`SqliteResponse::Query`], where the + /// response blob contains the results of the query. Any error will be contained + /// in the [`SqliteResponse::Err`] variant. + Query(String), + /// Begins a new transaction for atomic operations. + /// + /// Sending this will prompt a [`SqliteResponse::BeginTx`] response with the + /// transaction ID. Any error will be contained in the [`SqliteResponse::Err`] variant. BeginTx, - Commit { - tx_id: u64, - }, - Backup, + /// Commits all operations in the specified transaction. + /// + /// # Parameters + /// * `tx_id` - The ID of the transaction to commit + /// + /// A successful commit will respond with [`SqliteResponse::Ok`]. Any error will be + /// contained in the [`SqliteResponse::Err`] variant. + Commit { tx_id: u64 }, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum SqliteResponse { + /// Indicates successful completion of an operation. + /// Sent in response to actions Open, RemoveDb, Write, Query, BeginTx, and Commit. Ok, + /// Returns the results of a query. + /// + /// * blob: Vec> - Array of rows, where each row contains SqlValue types: + /// - null + /// - boolean + /// - i64 + /// - f64 + /// - String + /// - Vec (binary data) Read, + /// Returns the transaction ID for a newly created transaction. + /// + /// # Fields + /// * `tx_id` - The ID of the newly created transaction BeginTx { tx_id: u64 }, + /// Indicates an error occurred during the operation. Err(SqliteError), } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +/// Used in blobs to represent array row values in SQLite. +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] pub enum SqlValue { Integer(i64), Real(f64), @@ -47,28 +121,50 @@ pub enum SqlValue { Null, } -#[derive(Debug, Serialize, Deserialize, Error)] +#[derive(Clone, Debug, Serialize, Deserialize, Error)] pub enum SqliteError { - #[error("sqlite: DbDoesNotExist")] - NoDb, - #[error("sqlite: NoTx")] - NoTx, - #[error("sqlite: No capability: {error}")] - NoCap { error: String }, - #[error("sqlite: UnexpectedResponse")] - UnexpectedResponse, - #[error("sqlite: NotAWriteKeyword")] + #[error("db [{0}, {1}] does not exist")] + NoDb(PackageId, String), + #[error("no transaction {0} found")] + NoTx(u64), + #[error("no write capability for requested DB")] + NoWriteCap, + #[error("no read capability for requested DB")] + NoReadCap, + #[error("request to open or remove DB with mismatching package ID")] + MismatchingPackageId, + #[error("failed to generate capability for new DB")] + AddCapFailed, + #[error("write statement started with non-existent write keyword")] NotAWriteKeyword, - #[error("sqlite: NotAReadKeyword")] + #[error("read query started with non-existent read keyword")] NotAReadKeyword, - #[error("sqlite: Invalid Parameters")] + #[error("parameters blob in read/write was misshapen or contained invalid JSON objects")] InvalidParameters, - #[error("sqlite: IO error: {error}")] - IOError { error: String }, - #[error("sqlite: rusqlite error: {error}")] - RusqliteError { error: String }, - #[error("sqlite: input bytes/json/key error: {error}")] - InputError { error: String }, + #[error("sqlite got a malformed request that failed to deserialize")] + MalformedRequest, + #[error("rusqlite error: {0}")] + RusqliteError(String), + #[error("IO error: {0}")] + IOError(String), +} + +/// The JSON parameters contained in all capabilities issued by `sqlite:distro:sys`. +/// +/// # Fields +/// * `kind` - The kind of capability, either [`SqliteCapabilityKind::Read`] or [`SqliteCapabilityKind::Write`] +/// * `db_key` - The database key, a tuple of the [`PackageId`] that created the database and the database name +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct SqliteCapabilityParams { + pub kind: SqliteCapabilityKind, + pub db_key: (PackageId, String), +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum SqliteCapabilityKind { + Read, + Write, } impl ToSql for SqlValue { @@ -101,40 +197,14 @@ impl FromSql for SqlValue { } } -impl std::fmt::Display for SqliteAction { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} - impl From for SqliteError { fn from(err: std::io::Error) -> Self { - SqliteError::IOError { - error: err.to_string(), - } + SqliteError::IOError(err.to_string()) } } impl From for SqliteError { fn from(err: rusqlite::Error) -> Self { - SqliteError::RusqliteError { - error: err.to_string(), - } - } -} - -impl From for SqliteError { - fn from(err: tokio::sync::oneshot::error::RecvError) -> Self { - SqliteError::NoCap { - error: err.to_string(), - } - } -} - -impl From> for SqliteError { - fn from(err: tokio::sync::mpsc::error::SendError) -> Self { - SqliteError::NoCap { - error: err.to_string(), - } + SqliteError::RusqliteError(err.to_string()) } }