use std::fmt::Debug; use std::os::raw::{c_char, c_int, c_void}; use std::panic::catch_unwind; use std::ptr; use super::expect_utf8; use super::free_boxed_hook; use super::Action; use crate::error::check; use crate::ffi; use crate::inner_connection::InnerConnection; use crate::types::ValueRef; use crate::Connection; use crate::Result; /// The possible cases for when a PreUpdateHook gets triggered. Allows access to the relevant /// functions for each case through the contained values. #[derive(Debug)] pub enum PreUpdateCase { /// Pre-update hook was triggered by an insert. Insert(PreUpdateNewValueAccessor), /// Pre-update hook was triggered by a delete. Delete(PreUpdateOldValueAccessor), /// Pre-update hook was triggered by an update. Update { #[allow(missing_docs)] old_value_accessor: PreUpdateOldValueAccessor, #[allow(missing_docs)] new_value_accessor: PreUpdateNewValueAccessor, }, /// This variant is not normally produced by SQLite. You may encounter it /// if you're using a different version than what's supported by this library. Unknown, } impl From for Action { fn from(puc: PreUpdateCase) -> Action { match puc { PreUpdateCase::Insert(_) => Action::SQLITE_INSERT, PreUpdateCase::Delete(_) => Action::SQLITE_DELETE, PreUpdateCase::Update { .. } => Action::SQLITE_UPDATE, PreUpdateCase::Unknown => Action::UNKNOWN, } } } /// An accessor to access the old values of the row being deleted/updated during the preupdate callback. #[derive(Debug)] pub struct PreUpdateOldValueAccessor { db: *mut ffi::sqlite3, old_row_id: i64, } impl PreUpdateOldValueAccessor { /// Get the amount of columns in the row being deleted/updated. pub fn get_column_count(&self) -> i32 { unsafe { ffi::sqlite3_preupdate_count(self.db) } } /// Get the depth of the query that triggered the preupdate hook. /// Returns 0 if the preupdate callback was invoked as a result of /// a direct insert, update, or delete operation; /// 1 for inserts, updates, or deletes invoked by top-level triggers; /// 2 for changes resulting from triggers called by top-level triggers; and so forth. pub fn get_query_depth(&self) -> i32 { unsafe { ffi::sqlite3_preupdate_depth(self.db) } } /// Get the row id of the row being updated/deleted. pub fn get_old_row_id(&self) -> i64 { self.old_row_id } /// Get the value of the row being updated/deleted at the specified index. pub fn get_old_column_value(&self, i: i32) -> Result { let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut(); unsafe { check(ffi::sqlite3_preupdate_old(self.db, i, &mut p_value))?; Ok(ValueRef::from_value(p_value)) } } } /// An accessor to access the new values of the row being inserted/updated /// during the preupdate callback. #[derive(Debug)] pub struct PreUpdateNewValueAccessor { db: *mut ffi::sqlite3, new_row_id: i64, } impl PreUpdateNewValueAccessor { /// Get the amount of columns in the row being inserted/updated. pub fn get_column_count(&self) -> i32 { unsafe { ffi::sqlite3_preupdate_count(self.db) } } /// Get the depth of the query that triggered the preupdate hook. /// Returns 0 if the preupdate callback was invoked as a result of /// a direct insert, update, or delete operation; /// 1 for inserts, updates, or deletes invoked by top-level triggers; /// 2 for changes resulting from triggers called by top-level triggers; and so forth. pub fn get_query_depth(&self) -> i32 { unsafe { ffi::sqlite3_preupdate_depth(self.db) } } /// Get the row id of the row being inserted/updated. pub fn get_new_row_id(&self) -> i64 { self.new_row_id } /// Get the value of the row being updated/deleted at the specified index. pub fn get_new_column_value(&self, i: i32) -> Result { let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut(); unsafe { check(ffi::sqlite3_preupdate_new(self.db, i, &mut p_value))?; Ok(ValueRef::from_value(p_value)) } } } impl Connection { /// Register a callback function to be invoked before /// a row is updated, inserted or deleted. /// /// The callback parameters are: /// /// - the name of the database ("main", "temp", ...), /// - the name of the table that is updated, /// - a variant of the PreUpdateCase enum which allows access to extra functions depending /// on whether it's an update, delete or insert. #[inline] pub fn preupdate_hook(&self, hook: Option) where F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static, { self.db.borrow_mut().preupdate_hook(hook); } } impl InnerConnection { #[inline] pub fn remove_preupdate_hook(&mut self) { self.preupdate_hook(None::); } /// ```compile_fail /// use rusqlite::{Connection, Result, hooks::PreUpdateCase}; /// fn main() -> Result<()> { /// let db = Connection::open_in_memory()?; /// { /// let mut called = std::sync::atomic::AtomicBool::new(false); /// db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| { /// called.store(true, std::sync::atomic::Ordering::Relaxed); /// })); /// } /// db.execute_batch("CREATE TABLE foo AS SELECT 1 AS bar;") /// } /// ``` fn preupdate_hook(&mut self, hook: Option) where F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static, { unsafe extern "C" fn call_boxed_closure( p_arg: *mut c_void, sqlite: *mut ffi::sqlite3, action_code: c_int, db_name: *const c_char, tbl_name: *const c_char, old_row_id: i64, new_row_id: i64, ) where F: FnMut(Action, &str, &str, &PreUpdateCase), { let action = Action::from(action_code); let preupdate_case = match action { Action::SQLITE_INSERT => PreUpdateCase::Insert(PreUpdateNewValueAccessor { db: sqlite, new_row_id, }), Action::SQLITE_DELETE => PreUpdateCase::Delete(PreUpdateOldValueAccessor { db: sqlite, old_row_id, }), Action::SQLITE_UPDATE => PreUpdateCase::Update { old_value_accessor: PreUpdateOldValueAccessor { db: sqlite, old_row_id, }, new_value_accessor: PreUpdateNewValueAccessor { db: sqlite, new_row_id, }, }, Action::UNKNOWN => PreUpdateCase::Unknown, }; drop(catch_unwind(|| { let boxed_hook: *mut F = p_arg.cast::(); (*boxed_hook)( action, expect_utf8(db_name, "database name"), expect_utf8(tbl_name, "table name"), &preupdate_case, ); })); } let free_preupdate_hook = if hook.is_some() { Some(free_boxed_hook:: as unsafe fn(*mut c_void)) } else { None }; let previous_hook = match hook { Some(hook) => { let boxed_hook: *mut F = Box::into_raw(Box::new(hook)); unsafe { ffi::sqlite3_preupdate_hook( self.db(), Some(call_boxed_closure::), boxed_hook.cast(), ) } } _ => unsafe { ffi::sqlite3_preupdate_hook(self.db(), None, ptr::null_mut()) }, }; if !previous_hook.is_null() { if let Some(free_boxed_hook) = self.free_preupdate_hook { unsafe { free_boxed_hook(previous_hook) }; } } self.free_preupdate_hook = free_preupdate_hook; } } #[cfg(test)] mod test { use std::sync::atomic::{AtomicBool, Ordering}; use super::super::Action; use super::PreUpdateCase; use crate::{Connection, Result}; #[test] fn test_preupdate_hook_insert() -> Result<()> { let db = Connection::open_in_memory()?; static CALLED: AtomicBool = AtomicBool::new(false); db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| { assert_eq!(Action::SQLITE_INSERT, action); assert_eq!("main", db); assert_eq!("foo", tbl); match case { PreUpdateCase::Insert(accessor) => { assert_eq!(1, accessor.get_column_count()); assert_eq!(1, accessor.get_new_row_id()); assert_eq!(0, accessor.get_query_depth()); // out of bounds access should return an error assert!(accessor.get_new_column_value(1).is_err()); assert_eq!( "lisa", accessor.get_new_column_value(0).unwrap().as_str().unwrap() ); assert_eq!(0, accessor.get_query_depth()); } _ => panic!("wrong preupdate case"), } CALLED.store(true, Ordering::Relaxed); })); db.execute_batch("CREATE TABLE foo (t TEXT)")?; db.execute_batch("INSERT INTO foo VALUES ('lisa')")?; assert!(CALLED.load(Ordering::Relaxed)); Ok(()) } #[test] fn test_preupdate_hook_delete() -> Result<()> { let db = Connection::open_in_memory()?; static CALLED: AtomicBool = AtomicBool::new(false); db.execute_batch("CREATE TABLE foo (t TEXT)")?; db.execute_batch("INSERT INTO foo VALUES ('lisa')")?; db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| { assert_eq!(Action::SQLITE_DELETE, action); assert_eq!("main", db); assert_eq!("foo", tbl); match case { PreUpdateCase::Delete(accessor) => { assert_eq!(1, accessor.get_column_count()); assert_eq!(1, accessor.get_old_row_id()); assert_eq!(0, accessor.get_query_depth()); // out of bounds access should return an error assert!(accessor.get_old_column_value(1).is_err()); assert_eq!( "lisa", accessor.get_old_column_value(0).unwrap().as_str().unwrap() ); assert_eq!(0, accessor.get_query_depth()); } _ => panic!("wrong preupdate case"), } CALLED.store(true, Ordering::Relaxed); })); db.execute_batch("DELETE from foo")?; assert!(CALLED.load(Ordering::Relaxed)); Ok(()) } #[test] fn test_preupdate_hook_update() -> Result<()> { let db = Connection::open_in_memory()?; static CALLED: AtomicBool = AtomicBool::new(false); db.execute_batch("CREATE TABLE foo (t TEXT)")?; db.execute_batch("INSERT INTO foo VALUES ('lisa')")?; db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| { assert_eq!(Action::SQLITE_UPDATE, action); assert_eq!("main", db); assert_eq!("foo", tbl); match case { PreUpdateCase::Update { old_value_accessor, new_value_accessor, } => { assert_eq!(1, old_value_accessor.get_column_count()); assert_eq!(1, old_value_accessor.get_old_row_id()); assert_eq!(0, old_value_accessor.get_query_depth()); // out of bounds access should return an error assert!(old_value_accessor.get_old_column_value(1).is_err()); assert_eq!( "lisa", old_value_accessor .get_old_column_value(0) .unwrap() .as_str() .unwrap() ); assert_eq!(0, old_value_accessor.get_query_depth()); assert_eq!(1, new_value_accessor.get_column_count()); assert_eq!(1, new_value_accessor.get_new_row_id()); assert_eq!(0, new_value_accessor.get_query_depth()); // out of bounds access should return an error assert!(new_value_accessor.get_new_column_value(1).is_err()); assert_eq!( "janice", new_value_accessor .get_new_column_value(0) .unwrap() .as_str() .unwrap() ); assert_eq!(0, new_value_accessor.get_query_depth()); } _ => panic!("wrong preupdate case"), } CALLED.store(true, Ordering::Relaxed); })); db.execute_batch("UPDATE foo SET t = 'janice'")?; assert!(CALLED.load(Ordering::Relaxed)); Ok(()) } }