diff --git a/Cargo.toml b/Cargo.toml index 0137e18..9849310 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,7 +37,8 @@ trace = ["libsqlite3-sys/min_sqlite_version_3_6_23"] bundled = ["libsqlite3-sys/bundled", "modern_sqlite"] buildtime_bindgen = ["libsqlite3-sys/buildtime_bindgen"] limits = [] -hooks = ["libsqlite3-sys/preupdate_hook"] +hooks = [] +preupdate_hook = ["libsqlite3-sys/preupdate_hook"] i128_blob = ["byteorder"] sqlcipher = ["libsqlite3-sys/sqlcipher"] unlock_notify = ["libsqlite3-sys/unlock_notify"] diff --git a/src/hooks.rs b/src/hooks.rs index d2cc1cd..99dfe02 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -1,14 +1,9 @@ //! `feature = "hooks"` Commit, Data Change and Rollback Notification Callbacks #![allow(non_camel_case_types)] - -use std::os::raw::{c_char, c_int, c_void}; -use std::panic::{catch_unwind, RefUnwindSafe}; -use std::ptr; +use std::os::raw::c_void; use crate::ffi; -use crate::{Connection, InnerConnection}; - /// `feature = "hooks"` Action Codes #[derive(Clone, Copy, Debug, PartialEq)] #[repr(i32)] @@ -36,452 +31,555 @@ impl From for Action { } } -impl Connection { - /// `feature = "hooks"` Register a callback function to be invoked whenever - /// a transaction is committed. - /// - /// The callback returns `true` to rollback. - #[inline] - pub fn commit_hook<'c, F>(&'c self, hook: Option) - where - F: FnMut() -> bool + Send + 'c, - { - self.db.borrow_mut().commit_hook(hook); - } - - /// `feature = "hooks"` Register a callback function to be invoked whenever - /// a transaction is committed. - /// - /// The callback returns `true` to rollback. - #[inline] - pub fn rollback_hook<'c, F>(&'c self, hook: Option) - where - F: FnMut() + Send + 'c, - { - self.db.borrow_mut().rollback_hook(hook); - } - - /// `feature = "hooks"` Register a callback function to be invoked whenever - /// a row is updated, inserted or deleted in a rowid table. - /// - /// The callback parameters are: - /// - /// - the type of database update (SQLITE_INSERT, SQLITE_UPDATE or - /// SQLITE_DELETE), - /// - the name of the database ("main", "temp", ...), - /// - the name of the table that is updated, - /// - the ROWID of the row that is updated. - #[inline] - pub fn update_hook<'c, F>(&'c self, hook: Option) - where - F: FnMut(Action, &str, &str, i64) + Send + 'c, - { - self.db.borrow_mut().update_hook(hook); - } - /// - /// `feature = "hooks"` Register a callback function to be invoked before - /// a row is updated, inserted or deleted in a rowid table. - /// - /// The callback parameters are: - /// - /// - the type of database update (SQLITE_INSERT, SQLITE_UPDATE or - /// SQLITE_DELETE), - /// - the name of the database ("main", "temp", ...), - /// - the name of the table that is updated, - /// - for an update or delete, the initial ROWID of the row that is going to be updated/deleted. It is undefined for inserts. - /// - for an update or insert, the final ROWID of the row that is going to be updated/inserted. It is undefined for deletes. - #[inline] - pub fn preupdate_hook<'c, F>(&'c self, hook: Option) - where - F: FnMut(Action, &str, &str, i64, i64) + Send + 'c, - { - self.db.borrow_mut().preupdate_hook(hook); - } - - /// `feature = "hooks"` Register a query progress callback. - /// - /// The parameter `num_ops` is the approximate number of virtual machine - /// instructions that are evaluated between successive invocations of the - /// `handler`. If `num_ops` is less than one then the progress handler - /// is disabled. - /// - /// If the progress callback returns `true`, the operation is interrupted. - pub fn progress_handler(&self, num_ops: c_int, handler: Option) - where - F: FnMut() -> bool + Send + RefUnwindSafe + 'static, - { - self.db.borrow_mut().progress_handler(num_ops, handler); - } -} - -impl InnerConnection { - #[inline] - pub fn remove_hooks(&mut self) { - self.update_hook(None::); - self.preupdate_hook(None::); - self.commit_hook(None:: bool>); - self.rollback_hook(None::); - self.progress_handler(0, None:: bool>); - } - - fn commit_hook<'c, F>(&'c mut self, hook: Option) - where - F: FnMut() -> bool + Send + 'c, - { - unsafe extern "C" fn call_boxed_closure(p_arg: *mut c_void) -> c_int - where - F: FnMut() -> bool, - { - let r = catch_unwind(|| { - let boxed_hook: *mut F = p_arg as *mut F; - (*boxed_hook)() - }); - if let Ok(true) = r { - 1 - } else { - 0 - } - } - - // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with - // `sqlite3_commit_hook`. so we keep the `xDestroy` function in - // `InnerConnection.free_boxed_hook`. - let free_commit_hook = if hook.is_some() { - Some(free_boxed_hook:: as unsafe fn(*mut c_void)) - } else { - None - }; - - let previous_hook = match hook { - Some(hook) => { - let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); - unsafe { - ffi::sqlite3_commit_hook( - self.db(), - Some(call_boxed_closure::), - boxed_hook as *mut _, - ) - } - } - _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) }, - }; - if !previous_hook.is_null() { - if let Some(free_boxed_hook) = self.free_commit_hook { - unsafe { free_boxed_hook(previous_hook) }; - } - } - self.free_commit_hook = free_commit_hook; - } - - fn rollback_hook<'c, F>(&'c mut self, hook: Option) - where - F: FnMut() + Send + 'c, - { - unsafe extern "C" fn call_boxed_closure(p_arg: *mut c_void) - where - F: FnMut(), - { - let _ = catch_unwind(|| { - let boxed_hook: *mut F = p_arg as *mut F; - (*boxed_hook)(); - }); - } - - let free_rollback_hook = if hook.is_some() { - Some(free_boxed_hook:: as unsafe fn(*mut c_void)) - } else { - None - }; - - let previous_hook = match hook { - Some(hook) => { - let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); - unsafe { - ffi::sqlite3_rollback_hook( - self.db(), - Some(call_boxed_closure::), - boxed_hook as *mut _, - ) - } - } - _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) }, - }; - if !previous_hook.is_null() { - if let Some(free_boxed_hook) = self.free_rollback_hook { - unsafe { free_boxed_hook(previous_hook) }; - } - } - self.free_rollback_hook = free_rollback_hook; - } - - fn update_hook<'c, F>(&'c mut self, hook: Option) - where - F: FnMut(Action, &str, &str, i64) + Send + 'c, - { - unsafe extern "C" fn call_boxed_closure( - p_arg: *mut c_void, - action_code: c_int, - db_str: *const c_char, - tbl_str: *const c_char, - row_id: i64, - ) where - F: FnMut(Action, &str, &str, i64), - { - use std::ffi::CStr; - use std::str; - - let action = Action::from(action_code); - let db_name = { - let c_slice = CStr::from_ptr(db_str).to_bytes(); - str::from_utf8(c_slice) - }; - let tbl_name = { - let c_slice = CStr::from_ptr(tbl_str).to_bytes(); - str::from_utf8(c_slice) - }; - - let _ = catch_unwind(|| { - let boxed_hook: *mut F = p_arg as *mut F; - (*boxed_hook)( - action, - db_name.expect("illegal db name"), - tbl_name.expect("illegal table name"), - row_id, - ); - }); - } - - let free_update_hook = if hook.is_some() { - Some(free_boxed_hook:: as unsafe fn(*mut c_void)) - } else { - None - }; - - let previous_hook = match hook { - Some(hook) => { - let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); - unsafe { - ffi::sqlite3_update_hook( - self.db(), - Some(call_boxed_closure::), - boxed_hook as *mut _, - ) - } - } - _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) }, - }; - if !previous_hook.is_null() { - if let Some(free_boxed_hook) = self.free_update_hook { - unsafe { free_boxed_hook(previous_hook) }; - } - } - self.free_update_hook = free_update_hook; - } - - fn preupdate_hook<'c, F>(&'c mut self, hook: Option) - where - F: FnMut(Action, &str, &str, i64, i64) + Send + 'c, - { - unsafe extern "C" fn call_boxed_closure( - p_arg: *mut c_void, - _sqlite: *mut ffi::sqlite3, - action_code: c_int, - db_str: *const c_char, - tbl_str: *const c_char, - row_id: i64, - new_row_id: i64, - ) where - F: FnMut(Action, &str, &str, i64, i64), - { - use std::ffi::CStr; - use std::str; - - let action = Action::from(action_code); - let db_name = { - let c_slice = CStr::from_ptr(db_str).to_bytes(); - str::from_utf8(c_slice) - }; - let tbl_name = { - let c_slice = CStr::from_ptr(tbl_str).to_bytes(); - str::from_utf8(c_slice) - }; - - let _ = catch_unwind(|| { - let boxed_hook: *mut F = p_arg as *mut F; - (*boxed_hook)( - action, - db_name.expect("illegal db name"), - tbl_name.expect("illegal table name"), - row_id, - new_row_id, - ); - }); - } - - let free_preupdate_hook = if hook.is_some() { - Some(free_boxed_hook:: as unsafe fn(*mut c_void)) - } else { - None - }; - - let previous_hook = match hook { - Some(hook) => { - let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); - unsafe { - ffi::sqlite3_preupdate_hook( - self.db(), - Some(call_boxed_closure::), - boxed_hook as *mut _, - ) - } - } - _ => unsafe { ffi::sqlite3_preupdate_hook(self.db(), None, ptr::null_mut()) }, - }; - if !previous_hook.is_null() { - if let Some(free_boxed_hook) = self.free_preupdate_hook { - unsafe { free_boxed_hook(previous_hook) }; - } - } - self.free_preupdate_hook = free_preupdate_hook; - } - - fn progress_handler(&mut self, num_ops: c_int, handler: Option) - where - F: FnMut() -> bool + Send + RefUnwindSafe + 'static, - { - unsafe extern "C" fn call_boxed_closure(p_arg: *mut c_void) -> c_int - where - F: FnMut() -> bool, - { - let r = catch_unwind(|| { - let boxed_handler: *mut F = p_arg as *mut F; - (*boxed_handler)() - }); - if let Ok(true) = r { - 1 - } else { - 0 - } - } - - match handler { - Some(handler) => { - let boxed_handler = Box::new(handler); - unsafe { - ffi::sqlite3_progress_handler( - self.db(), - num_ops, - Some(call_boxed_closure::), - &*boxed_handler as *const F as *mut _, - ) - } - self.progress_handler = Some(boxed_handler); - } - _ => { - unsafe { ffi::sqlite3_progress_handler(self.db(), num_ops, None, ptr::null_mut()) } - self.progress_handler = None; - } - }; - } -} - unsafe fn free_boxed_hook(p: *mut c_void) { drop(Box::from_raw(p as *mut F)); } -#[cfg(test)] -mod test { +#[cfg(feature = "preupdate_hook")] +mod preupdate_hook { + use super::free_boxed_hook; use super::Action; - use crate::{Connection, Result}; - use std::sync::atomic::{AtomicBool, Ordering}; - #[test] - fn test_commit_hook() -> Result<()> { - let db = Connection::open_in_memory()?; + use std::os::raw::{c_char, c_int, c_void}; + use std::panic::catch_unwind; + use std::ptr; - let mut called = false; - db.commit_hook(Some(|| { - called = true; - false - })); - db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?; - assert!(called); - Ok(()) + use crate::ffi; + use crate::types::ValueRef; + use crate::{Connection, InnerConnection}; + + // TODO: how to allow user access to these functions, since they should be only accessible in + // the scope of a preupdate_hook callback. + pub struct PreUpdateHookFunctions { + db: *mut ffi::sqlite3, } - #[test] - fn test_fn_commit_hook() -> Result<()> { - let db = Connection::open_in_memory()?; - - fn hook() -> bool { - true + impl PreUpdateHookFunctions { + pub unsafe fn get_count(&self) -> i32 { + ffi::sqlite3_preupdate_count(self.db) } - db.commit_hook(Some(hook)); - db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") - .unwrap_err(); - Ok(()) - } - - #[test] - fn test_rollback_hook() -> Result<()> { - let db = Connection::open_in_memory()?; - - let mut called = false; - db.rollback_hook(Some(|| { - called = true; - })); - db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")?; - assert!(called); - Ok(()) - } - - #[test] - fn test_update_hook() -> Result<()> { - let db = Connection::open_in_memory()?; - - let mut called = false; - db.update_hook(Some(|action, db: &str, tbl: &str, row_id| { - assert_eq!(Action::SQLITE_INSERT, action); - assert_eq!("main", db); - assert_eq!("foo", tbl); - assert_eq!(1, row_id); - called = true; - })); - db.execute_batch("CREATE TABLE foo (t TEXT)")?; - db.execute_batch("INSERT INTO foo VALUES ('lisa')")?; - assert!(called); - Ok(()) - } - - #[test] - fn test_progress_handler() -> Result<()> { - let db = Connection::open_in_memory()?; - - static CALLED: AtomicBool = AtomicBool::new(false); - db.progress_handler( - 1, - Some(|| { - CALLED.store(true, Ordering::Relaxed); - false - }), - ); - db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?; - assert!(CALLED.load(Ordering::Relaxed)); - Ok(()) - } - - #[test] - fn test_progress_handler_interrupt() -> Result<()> { - let db = Connection::open_in_memory()?; - - fn handler() -> bool { - true + pub unsafe fn get_old(&self, i: i32) -> ValueRef { + let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut(); + ffi::sqlite3_preupdate_old(self.db, i, &mut p_value); + ValueRef::from_value(p_value) } - db.progress_handler(1, Some(handler)); - db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") - .unwrap_err(); - Ok(()) + pub unsafe fn get_new(&self, i: i32) -> ValueRef { + let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut(); + ffi::sqlite3_preupdate_new(self.db, i, &mut p_value); + ValueRef::from_value(p_value) + } + } + + impl Connection { + /// + /// `feature = "preupdate_hook"` Register a callback function to be invoked before + /// a row is updated, inserted or deleted in a rowid table. + /// + /// The callback parameters are: + /// + /// - the type of database update (SQLITE_INSERT, SQLITE_UPDATE or + /// SQLITE_DELETE), + /// - the name of the database ("main", "temp", ...), + /// - the name of the table that is updated, + /// - for an update or delete, the initial ROWID of the row that is going to be updated/deleted. It is undefined for inserts. + /// - for an update or insert, the final ROWID of the row that is going to be updated/inserted. It is undefined for deletes. + #[inline] + pub fn preupdate_hook<'c, F>(&'c self, hook: Option) + where + F: FnMut(Action, &str, &str, i64, i64, &PreUpdateHookFunctions) + Send + 'c, + { + self.db.borrow_mut().preupdate_hook(hook); + } + } + + impl InnerConnection { + #[inline] + pub fn remove_preupdate_hook(&mut self) { + self.preupdate_hook(None::); + } + + fn preupdate_hook<'c, F>(&'c mut self, hook: Option) + where + F: FnMut(Action, &str, &str, i64, i64, &PreUpdateHookFunctions) + Send + 'c, + { + unsafe extern "C" fn call_boxed_closure( + p_arg: *mut c_void, + sqlite: *mut ffi::sqlite3, + action_code: c_int, + db_str: *const c_char, + tbl_str: *const c_char, + row_id: i64, + new_row_id: i64, + ) where + F: FnMut(Action, &str, &str, i64, i64, &PreUpdateHookFunctions), + { + use std::ffi::CStr; + use std::str; + + let action = Action::from(action_code); + let db_name = { + let c_slice = CStr::from_ptr(db_str).to_bytes(); + str::from_utf8(c_slice) + }; + let tbl_name = { + let c_slice = CStr::from_ptr(tbl_str).to_bytes(); + str::from_utf8(c_slice) + }; + + // TODO: how to properly allow a user to use the functions + // (sqlite3_preupdate_old,...) that are only in scope + // during the callback? + // Also how to pass in the rowids, because they can be undefined based on the + // action. + let preupdate_hook_functions = PreUpdateHookFunctions { db: sqlite }; + + let _ = catch_unwind(|| { + let boxed_hook: *mut F = p_arg as *mut F; + (*boxed_hook)( + action, + db_name.expect("illegal db name"), + tbl_name.expect("illegal table name"), + row_id, + new_row_id, + &preupdate_hook_functions, + ); + }); + } + + let free_preupdate_hook = if hook.is_some() { + Some(free_boxed_hook:: as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_preupdate_hook( + self.db(), + Some(call_boxed_closure::), + boxed_hook as *mut _, + ) + } + } + _ => unsafe { ffi::sqlite3_preupdate_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_preupdate_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_preupdate_hook = free_preupdate_hook; + } + } + + #[cfg(test)] + mod test { + use super::super::Action; + use super::PreUpdateHookFunctions; + use crate::{Connection, Result}; + + #[test] + fn test_preupdate_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + let mut called = false; + db.preupdate_hook(Some( + |action, + db: &str, + tbl: &str, + row_id, + new_row_id, + _func: &PreUpdateHookFunctions| { + assert_eq!(Action::SQLITE_INSERT, action); + assert_eq!("main", db); + assert_eq!("foo", tbl); + assert_eq!(1, row_id); + assert_eq!(1, new_row_id); + called = true; + }, + )); + db.execute_batch("CREATE TABLE foo (t TEXT)")?; + db.execute_batch("INSERT INTO foo VALUES ('lisa')")?; + assert!(called); + Ok(()) + } + } +} + +#[cfg(feature = "hooks")] +mod datachanged_and_friends { + use super::free_boxed_hook; + use super::Action; + + use std::os::raw::{c_char, c_int, c_void}; + use std::panic::{catch_unwind, RefUnwindSafe}; + use std::ptr; + + use crate::ffi; + use crate::{Connection, InnerConnection}; + + impl Connection { + /// `feature = "hooks"` Register a callback function to be invoked whenever + /// a transaction is committed. + /// + /// The callback returns `true` to rollback. + #[inline] + pub fn commit_hook<'c, F>(&'c self, hook: Option) + where + F: FnMut() -> bool + Send + 'c, + { + self.db.borrow_mut().commit_hook(hook); + } + + /// `feature = "hooks"` Register a callback function to be invoked whenever + /// a transaction is committed. + /// + /// The callback returns `true` to rollback. + #[inline] + pub fn rollback_hook<'c, F>(&'c self, hook: Option) + where + F: FnMut() + Send + 'c, + { + self.db.borrow_mut().rollback_hook(hook); + } + + /// `feature = "hooks"` Register a callback function to be invoked whenever + /// a row is updated, inserted or deleted in a rowid table. + /// + /// The callback parameters are: + /// + /// - the type of database update (SQLITE_INSERT, SQLITE_UPDATE or + /// SQLITE_DELETE), + /// - the name of the database ("main", "temp", ...), + /// - the name of the table that is updated, + /// - the ROWID of the row that is updated. + #[inline] + pub fn update_hook<'c, F>(&'c self, hook: Option) + where + F: FnMut(Action, &str, &str, i64) + Send + 'c, + { + self.db.borrow_mut().update_hook(hook); + } + + /// `feature = "hooks"` Register a query progress callback. + /// + /// The parameter `num_ops` is the approximate number of virtual machine + /// instructions that are evaluated between successive invocations of the + /// `handler`. If `num_ops` is less than one then the progress handler + /// is disabled. + /// + /// If the progress callback returns `true`, the operation is interrupted. + pub fn progress_handler(&self, num_ops: c_int, handler: Option) + where + F: FnMut() -> bool + Send + RefUnwindSafe + 'static, + { + self.db.borrow_mut().progress_handler(num_ops, handler); + } + } + + impl InnerConnection { + #[inline] + pub fn remove_hooks(&mut self) { + self.update_hook(None::); + self.commit_hook(None:: bool>); + self.rollback_hook(None::); + self.progress_handler(0, None:: bool>); + } + + fn commit_hook<'c, F>(&'c mut self, hook: Option) + where + F: FnMut() -> bool + Send + 'c, + { + unsafe extern "C" fn call_boxed_closure(p_arg: *mut c_void) -> c_int + where + F: FnMut() -> bool, + { + let r = catch_unwind(|| { + let boxed_hook: *mut F = p_arg as *mut F; + (*boxed_hook)() + }); + if let Ok(true) = r { + 1 + } else { + 0 + } + } + + // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with + // `sqlite3_commit_hook`. so we keep the `xDestroy` function in + // `InnerConnection.free_boxed_hook`. + let free_commit_hook = if hook.is_some() { + Some(free_boxed_hook:: as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_commit_hook( + self.db(), + Some(call_boxed_closure::), + boxed_hook as *mut _, + ) + } + } + _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_commit_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_commit_hook = free_commit_hook; + } + + fn rollback_hook<'c, F>(&'c mut self, hook: Option) + where + F: FnMut() + Send + 'c, + { + unsafe extern "C" fn call_boxed_closure(p_arg: *mut c_void) + where + F: FnMut(), + { + let _ = catch_unwind(|| { + let boxed_hook: *mut F = p_arg as *mut F; + (*boxed_hook)(); + }); + } + + let free_rollback_hook = if hook.is_some() { + Some(free_boxed_hook:: as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_rollback_hook( + self.db(), + Some(call_boxed_closure::), + boxed_hook as *mut _, + ) + } + } + _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_rollback_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_rollback_hook = free_rollback_hook; + } + + fn update_hook<'c, F>(&'c mut self, hook: Option) + where + F: FnMut(Action, &str, &str, i64) + Send + 'c, + { + unsafe extern "C" fn call_boxed_closure( + p_arg: *mut c_void, + action_code: c_int, + db_str: *const c_char, + tbl_str: *const c_char, + row_id: i64, + ) where + F: FnMut(Action, &str, &str, i64), + { + use std::ffi::CStr; + use std::str; + + let action = Action::from(action_code); + let db_name = { + let c_slice = CStr::from_ptr(db_str).to_bytes(); + str::from_utf8(c_slice) + }; + let tbl_name = { + let c_slice = CStr::from_ptr(tbl_str).to_bytes(); + str::from_utf8(c_slice) + }; + + let _ = catch_unwind(|| { + let boxed_hook: *mut F = p_arg as *mut F; + (*boxed_hook)( + action, + db_name.expect("illegal db name"), + tbl_name.expect("illegal table name"), + row_id, + ); + }); + } + + let free_update_hook = if hook.is_some() { + Some(free_boxed_hook:: as unsafe fn(*mut c_void)) + } else { + None + }; + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_update_hook( + self.db(), + Some(call_boxed_closure::), + boxed_hook as *mut _, + ) + } + } + _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_update_hook { + unsafe { free_boxed_hook(previous_hook) }; + } + } + self.free_update_hook = free_update_hook; + } + + fn progress_handler(&mut self, num_ops: c_int, handler: Option) + where + F: FnMut() -> bool + Send + RefUnwindSafe + 'static, + { + unsafe extern "C" fn call_boxed_closure(p_arg: *mut c_void) -> c_int + where + F: FnMut() -> bool, + { + let r = catch_unwind(|| { + let boxed_handler: *mut F = p_arg as *mut F; + (*boxed_handler)() + }); + if let Ok(true) = r { + 1 + } else { + 0 + } + } + + match handler { + Some(handler) => { + let boxed_handler = Box::new(handler); + unsafe { + ffi::sqlite3_progress_handler( + self.db(), + num_ops, + Some(call_boxed_closure::), + &*boxed_handler as *const F as *mut _, + ) + } + self.progress_handler = Some(boxed_handler); + } + _ => { + unsafe { + ffi::sqlite3_progress_handler(self.db(), num_ops, None, ptr::null_mut()) + } + self.progress_handler = None; + } + }; + } + } + + #[cfg(test)] + mod test { + use super::super::Action; + use crate::{Connection, Result}; + use std::sync::atomic::{AtomicBool, Ordering}; + + #[test] + fn test_commit_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + let mut called = false; + db.commit_hook(Some(|| { + called = true; + false + })); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?; + assert!(called); + Ok(()) + } + + #[test] + fn test_fn_commit_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + fn hook() -> bool { + true + } + + db.commit_hook(Some(hook)); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") + .unwrap_err(); + Ok(()) + } + + #[test] + fn test_rollback_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + let mut called = false; + db.rollback_hook(Some(|| { + called = true; + })); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;")?; + assert!(called); + Ok(()) + } + + #[test] + fn test_update_hook() -> Result<()> { + let db = Connection::open_in_memory()?; + + let mut called = false; + db.update_hook(Some(|action, db: &str, tbl: &str, row_id| { + assert_eq!(Action::SQLITE_INSERT, action); + assert_eq!("main", db); + assert_eq!("foo", tbl); + assert_eq!(1, row_id); + called = true; + })); + db.execute_batch("CREATE TABLE foo (t TEXT)")?; + db.execute_batch("INSERT INTO foo VALUES ('lisa')")?; + assert!(called); + Ok(()) + } + + #[test] + fn test_progress_handler() -> Result<()> { + let db = Connection::open_in_memory()?; + + static CALLED: AtomicBool = AtomicBool::new(false); + db.progress_handler( + 1, + Some(|| { + CALLED.store(true, Ordering::Relaxed); + false + }), + ); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")?; + assert!(CALLED.load(Ordering::Relaxed)); + Ok(()) + } + + #[test] + fn test_progress_handler_interrupt() -> Result<()> { + let db = Connection::open_in_memory()?; + + fn handler() -> bool { + true + } + + db.progress_handler(1, Some(handler)); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") + .unwrap_err(); + Ok(()) + } } } diff --git a/src/inner_connection.rs b/src/inner_connection.rs index 7579ad5..759d406 100644 --- a/src/inner_connection.rs +++ b/src/inner_connection.rs @@ -33,7 +33,7 @@ pub struct InnerConnection { pub free_update_hook: Option, #[cfg(feature = "hooks")] pub progress_handler: Option bool + Send>>, - #[cfg(feature = "hooks")] + #[cfg(feature = "preupdate_hook")] pub free_preupdate_hook: Option, owned: bool, } @@ -52,9 +52,9 @@ impl InnerConnection { #[cfg(feature = "hooks")] free_update_hook: None, #[cfg(feature = "hooks")] - free_preupdate_hook: None, - #[cfg(feature = "hooks")] progress_handler: None, + #[cfg(feature = "preupdate_hook")] + free_preupdate_hook: None, owned, } } @@ -155,6 +155,7 @@ impl InnerConnection { return Ok(()); } self.remove_hooks(); + self.remove_preupdate_hook(); let mut shared_handle = self.interrupt_lock.lock().unwrap(); assert!( !shared_handle.is_null(), @@ -305,6 +306,10 @@ impl InnerConnection { #[cfg(not(feature = "hooks"))] #[inline] fn remove_hooks(&mut self) {} + + #[cfg(not(feature = "preupdate_hook"))] + #[inline] + fn remove_preupdate_hook(&mut self) {} } impl Drop for InnerConnection { diff --git a/src/lib.rs b/src/lib.rs index 4635c19..7c916e4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -101,8 +101,8 @@ pub mod config; mod context; #[cfg(feature = "functions")] pub mod functions; -#[cfg(feature = "hooks")] -mod hooks; +#[cfg(any(feature = "hooks", feature = "preupdate_hook"))] +pub mod hooks; mod inner_connection; #[cfg(feature = "limits")] pub mod limits; diff --git a/src/types/value_ref.rs b/src/types/value_ref.rs index b95521b..0575b66 100644 --- a/src/types/value_ref.rs +++ b/src/types/value_ref.rs @@ -133,7 +133,12 @@ where } } -#[cfg(any(feature = "functions", feature = "session", feature = "vtab"))] +#[cfg(any( + feature = "functions", + feature = "session", + feature = "vtab", + feature = "preupdate_hook" +))] impl<'a> ValueRef<'a> { pub(crate) unsafe fn from_value(value: *mut crate::ffi::sqlite3_value) -> ValueRef<'a> { use crate::ffi;