diff --git a/Cargo.toml b/Cargo.toml index fb951ef..0137e18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,7 @@ trace = ["libsqlite3-sys/min_sqlite_version_3_6_23"] bundled = ["libsqlite3-sys/bundled", "modern_sqlite"] buildtime_bindgen = ["libsqlite3-sys/buildtime_bindgen"] limits = [] -hooks = [] +hooks = ["libsqlite3-sys/preupdate_hook"] i128_blob = ["byteorder"] sqlcipher = ["libsqlite3-sys/sqlcipher"] unlock_notify = ["libsqlite3-sys/unlock_notify"] diff --git a/src/hooks.rs b/src/hooks.rs index b2ed430..1f5ae8a 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -78,6 +78,25 @@ impl Connection { { self.db.borrow_mut().update_hook(hook); } + /// + /// `feature = "hooks"` Register a callback function to be invoked before + /// a row is updated, inserted or deleted in a rowid table. + /// + /// The callback parameters are: + /// + /// - the type of database update (SQLITE_INSERT, SQLITE_UPDATE or + /// SQLITE_DELETE), + /// - the name of the database ("main", "temp", ...), + /// - the name of the table that is updated, + /// - for an update or delete, the initial ROWID of the row that is going to be updated/deleted. It is undefined for inserts. + /// - for an update or insert, the final ROWID of the row that is going to be updated/inserted. It is undefined for deletes. + #[inline] + pub fn preupdate_hook<'c, F>(&'c self, hook: Option) + where + F: FnMut(Action, &str, &str, i64, i64) + Send + 'c, + { + self.db.borrow_mut().preupdate_hook(hook); + } /// `feature = "hooks"` Register a query progress callback. /// @@ -99,6 +118,7 @@ impl InnerConnection { #[inline] pub fn remove_hooks(&mut self) { self.update_hook(None::); + self.preupdate_hook(None::); self.commit_hook(None:: bool>); self.rollback_hook(None::); self.progress_handler(0, None:: bool>); @@ -258,6 +278,73 @@ impl InnerConnection { self.free_update_hook = free_update_hook; } + fn preupdate_hook<'c, F>(&'c mut self, hook: Option) + where + F: FnMut(Action, &str, &str, i64, i64) + Send + 'c, + { + unsafe extern "C" fn call_boxed_closure( + p_arg: *mut c_void, + sqlite: *mut ffi::sqlite3, + action_code: c_int, + db_str: *const c_char, + tbl_str: *const c_char, + row_id: i64, + new_row_id: i64, + ) where + F: FnMut(Action, &str, &str, i64, i64), + { + use std::ffi::CStr; + use std::str; + + let action = Action::from(action_code); + let db_name = { + let c_slice = CStr::from_ptr(db_str).to_bytes(); + str::from_utf8(c_slice) + }; + let tbl_name = { + let c_slice = CStr::from_ptr(tbl_str).to_bytes(); + str::from_utf8(c_slice) + }; + + let _ = catch_unwind(|| { + let boxed_hook: *mut F = p_arg as *mut F; + (*boxed_hook)( + action, + db_name.expect("illegal db name"), + tbl_name.expect("illegal table name"), + row_id, + new_row_id, + ); + }); + } + + let free_preupdate_hook = if hook.is_some() { + Some(free_boxed_hook:: as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_preupdate_hook( + self.db(), + Some(call_boxed_closure::), + boxed_hook as *mut _, + ) + } + } + _ => unsafe { ffi::sqlite3_preupdate_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_preupdate_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_preupdate_hook = free_preupdate_hook; + } + fn progress_handler(&mut self, num_ops: c_int, handler: Option) where F: FnMut() -> bool + Send + RefUnwindSafe + 'static, diff --git a/src/inner_connection.rs b/src/inner_connection.rs index 133b7ef..7579ad5 100644 --- a/src/inner_connection.rs +++ b/src/inner_connection.rs @@ -33,6 +33,8 @@ pub struct InnerConnection { pub free_update_hook: Option, #[cfg(feature = "hooks")] pub progress_handler: Option bool + Send>>, + #[cfg(feature = "hooks")] + pub free_preupdate_hook: Option, owned: bool, } @@ -50,6 +52,8 @@ impl InnerConnection { #[cfg(feature = "hooks")] free_update_hook: None, #[cfg(feature = "hooks")] + free_preupdate_hook: None, + #[cfg(feature = "hooks")] progress_handler: None, owned, }