diff --git a/src/functions.rs b/src/functions.rs index 91aecf2..10db830 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -182,7 +182,7 @@ impl<'a> ValueRef<'a> { } unsafe extern "C" fn free_boxed_value(p: *mut c_void) { - let _: Box = Box::from_raw(p as *mut T); + drop(Box::from_raw(p as *mut T)); } /// Context is a wrapper for the SQLite function evaluation context. diff --git a/src/hooks.rs b/src/hooks.rs index 86f4807..b3e2fa6 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -94,8 +94,9 @@ impl Connection { /// Register a callback function to be invoked whenever a transaction is committed. /// /// The callback returns `true` to rollback. - pub fn commit_hook(&self, hook: F) - where F: FnMut() -> bool + pub fn commit_hook(&self, hook: Option) + where + F: FnMut() -> bool, { self.db.borrow_mut().commit_hook(hook); } @@ -103,8 +104,9 @@ impl Connection { /// Register a callback function to be invoked whenever a transaction is committed. /// /// The callback returns `true` to rollback. - pub fn rollback_hook(&self, hook: F) - where F: FnMut() + pub fn rollback_hook(&self, hook: Option) + where + F: FnMut(), { self.db.borrow_mut().rollback_hook(hook); } @@ -118,99 +120,122 @@ impl Connection { /// - the name of the database ("main", "temp", ...), /// - the name of the table that is updated, /// - the ROWID of the row that is updated. - pub fn update_hook(&self, hook: F) - where F: FnMut(Action, &str, &str, i64) + pub fn update_hook(&self, hook: Option) + where + F: FnMut(Action, &str, &str, i64), { self.db.borrow_mut().update_hook(hook); } - - /// Remove hook installed by `update_hook`. - pub fn remove_update_hook(&self) { - self.db.borrow_mut().remove_update_hook(); - } - - /// Remove hook installed by `commit_hook`. - pub fn remove_commit_hook(&self) { - self.db.borrow_mut().remove_commit_hook(); - } - - /// Remove hook installed by `rollback_hook`. - pub fn remove_rollback_hook(&self) { - self.db.borrow_mut().remove_rollback_hook(); - } } impl InnerConnection { pub fn remove_hooks(&mut self) { - self.remove_update_hook(); - self.remove_commit_hook(); - self.remove_rollback_hook(); + self.update_hook(None::); + self.commit_hook(None:: bool>); + self.rollback_hook(None::); } - fn commit_hook(&self, hook: F) - where F: FnMut() -> bool + fn commit_hook(&mut self, hook: Option) + where + F: FnMut() -> bool, { unsafe extern "C" fn call_boxed_closure(p_arg: *mut c_void) -> c_int - where F: FnMut() -> bool + where + F: FnMut() -> bool, { let boxed_hook: *mut F = p_arg as *mut F; - assert!(!boxed_hook.is_null(), - "Internal error - null function pointer"); - - if (*boxed_hook)() { 1 } else { 0 } + if (*boxed_hook)() { + 1 + } else { + 0 + } } - let previous_hook = { - let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); - unsafe { - ffi::sqlite3_commit_hook(self.db(), - Some(call_boxed_closure::), - boxed_hook as *mut _) - } + // unlike `sqlite3_create_function_v2`, we cannot specify a `xDestroy` with `sqlite3_commit_hook`. + // so we keep the `xDestroy` function in `InnerConnection.free_boxed_hook`. + let free_commit_hook = if hook.is_some() { + Some(free_boxed_hook:: as fn(*mut c_void)) + } else { + None }; - free_boxed_hook(previous_hook); + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_commit_hook( + self.db(), + Some(call_boxed_closure::), + boxed_hook as *mut _, + ) + } + } + _ => unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_commit_hook { + free_boxed_hook(previous_hook); + } + } + self.free_commit_hook = free_commit_hook; } - fn rollback_hook(&self, hook: F) - where F: FnMut() + fn rollback_hook(&mut self, hook: Option) + where + F: FnMut(), { unsafe extern "C" fn call_boxed_closure(p_arg: *mut c_void) - where F: FnMut() + where + F: FnMut(), { let boxed_hook: *mut F = p_arg as *mut F; - assert!(!boxed_hook.is_null(), - "Internal error - null function pointer"); - (*boxed_hook)(); } - let previous_hook = { - let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); - unsafe { - ffi::sqlite3_rollback_hook(self.db(), - Some(call_boxed_closure::), - boxed_hook as *mut _) - } + let free_rollback_hook = if hook.is_some() { + Some(free_boxed_hook:: as fn(*mut c_void)) + } else { + None }; - free_boxed_hook(previous_hook); + + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_rollback_hook( + self.db(), + Some(call_boxed_closure::), + boxed_hook as *mut _, + ) + } + } + _ => unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_rollback_hook { + free_boxed_hook(previous_hook); + } + } + self.free_rollback_hook = free_rollback_hook; } - fn update_hook(&mut self, hook: F) - where F: FnMut(Action, &str, &str, i64) + fn update_hook(&mut self, hook: Option) + where + F: FnMut(Action, &str, &str, i64), { - unsafe extern "C" fn call_boxed_closure(p_arg: *mut c_void, - action_code: c_int, - db_str: *const c_char, - tbl_str: *const c_char, - row_id: i64) - where F: FnMut(Action, &str, &str, i64) + unsafe extern "C" fn call_boxed_closure( + p_arg: *mut c_void, + action_code: c_int, + db_str: *const c_char, + tbl_str: *const c_char, + row_id: i64, + ) where + F: FnMut(Action, &str, &str, i64), { use std::ffi::CStr; use std::str; let boxed_hook: *mut F = p_arg as *mut F; - assert!(!boxed_hook.is_null(), - "Internal error - null function pointer"); let action = Action::from(action_code); let db_name = { @@ -225,38 +250,36 @@ impl InnerConnection { (*boxed_hook)(action, db_name, tbl_name, row_id); } - let previous_hook = { - let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); - unsafe { - ffi::sqlite3_update_hook(self.db(), - Some(call_boxed_closure::), - boxed_hook as *mut _) - } + let free_update_hook = if hook.is_some() { + Some(free_boxed_hook:: as fn(*mut c_void)) + } else { + None }; - free_boxed_hook(previous_hook); - } - fn remove_update_hook(&mut self) { - let previous_hook = unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) }; - free_boxed_hook(previous_hook); - } - - fn remove_commit_hook(&mut self) { - let previous_hook = unsafe { ffi::sqlite3_commit_hook(self.db(), None, ptr::null_mut()) }; - free_boxed_hook(previous_hook); - } - - fn remove_rollback_hook(&mut self) { - let previous_hook = unsafe { ffi::sqlite3_rollback_hook(self.db(), None, ptr::null_mut()) }; - free_boxed_hook(previous_hook); + let previous_hook = match hook { + Some(hook) => { + let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); + unsafe { + ffi::sqlite3_update_hook( + self.db(), + Some(call_boxed_closure::), + boxed_hook as *mut _, + ) + } + } + _ => unsafe { ffi::sqlite3_update_hook(self.db(), None, ptr::null_mut()) }, + }; + if !previous_hook.is_null() { + if let Some(free_boxed_hook) = self.free_update_hook { + free_boxed_hook(previous_hook); + } + } + self.free_update_hook = free_update_hook; } } -fn free_boxed_hook(hook: *mut c_void) { - if !hook.is_null() { - // TODO make sure that size_of::<*mut F>() is always equal to size_of::<*mut c_void>() - let _: Box<*mut c_void> = unsafe { Box::from_raw(hook as *mut _) }; - } +fn free_boxed_hook(p: *mut c_void) { + drop(unsafe { Box::from_raw(p as *mut F) }); } #[cfg(test)] @@ -269,21 +292,36 @@ mod test { let db = Connection::open_in_memory().unwrap(); let mut called = false; - db.commit_hook(|| { - called = true; - false - }); + db.commit_hook(Some(|| { + called = true; + false + })); db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") .unwrap(); assert!(called); } + #[test] + fn test_fn_commit_hook() { + let db = Connection::open_in_memory().unwrap(); + + fn hook() -> bool { + true + } + + db.commit_hook(Some(hook)); + db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") + .unwrap_err(); + } + #[test] fn test_rollback_hook() { let db = Connection::open_in_memory().unwrap(); let mut called = false; - db.rollback_hook(|| { called = true; }); + db.rollback_hook(Some(|| { + called = true; + })); db.execute_batch("BEGIN; CREATE TABLE foo (t TEXT); ROLLBACK;") .unwrap(); assert!(called); @@ -294,13 +332,13 @@ mod test { let db = Connection::open_in_memory().unwrap(); let mut called = false; - db.update_hook(|action, db, tbl, row_id| { - assert_eq!(Action::SQLITE_INSERT, action); - assert_eq!("main", db); - assert_eq!("foo", tbl); - assert_eq!(1, row_id); - called = true; - }); + db.update_hook(Some(|action, db: &str, tbl: &str, row_id| { + assert_eq!(Action::SQLITE_INSERT, action); + assert_eq!("main", db); + assert_eq!("foo", tbl); + assert_eq!(1, row_id); + called = true; + })); db.execute_batch("CREATE TABLE foo (t TEXT)").unwrap(); db.execute_batch("INSERT INTO foo VALUES ('lisa')").unwrap(); assert!(called); diff --git a/src/lib.rs b/src/lib.rs index 598520b..eb5bf4b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -72,7 +72,7 @@ use std::result; use std::str; use std::sync::{Once, ONCE_INIT}; use std::sync::atomic::{AtomicBool, ATOMIC_BOOL_INIT, Ordering}; -use std::os::raw::{c_int, c_char}; +use std::os::raw::{c_int, c_char, c_void}; use types::{ToSql, ValueRef}; use error::{error_from_sqlite_code, error_from_handle}; @@ -597,6 +597,12 @@ impl fmt::Debug for Connection { struct InnerConnection { db: *mut ffi::sqlite3, + #[cfg(feature = "hooks")] + free_commit_hook: Option, + #[cfg(feature = "hooks")] + free_rollback_hook: Option, + #[cfg(feature = "hooks")] + free_update_hook: Option, } /// Old name for `OpenFlags`. `SqliteOpenFlags` is deprecated. @@ -755,6 +761,15 @@ To fix this, either: } impl InnerConnection { + #[cfg(not(feature = "hooks"))] + fn new(db: *mut ffi::sqlite3) -> InnerConnection { + InnerConnection { db } + } + #[cfg(feature = "hooks")] + fn new(db: *mut ffi::sqlite3) -> InnerConnection { + InnerConnection { db, free_commit_hook: None, free_rollback_hook: None, free_update_hook: None } + } + fn open_with_flags(c_path: &CString, flags: OpenFlags) -> Result { ensure_valid_sqlite_version(); ensure_safe_sqlite_threading_mode()?; @@ -792,7 +807,7 @@ impl InnerConnection { // attempt to turn on extended results code; don't fail if we can't. ffi::sqlite3_extended_result_codes(db, 1); - Ok(InnerConnection { db }) + Ok(InnerConnection::new(db)) } }