mirror of
https://github.com/isar/rusqlite.git
synced 2025-04-04 14:24:27 +08:00
373 lines
13 KiB
Rust
373 lines
13 KiB
Rust
use std::fmt::Debug;
|
|
use std::os::raw::{c_char, c_int, c_void};
|
|
use std::panic::catch_unwind;
|
|
use std::ptr;
|
|
|
|
use super::expect_utf8;
|
|
use super::free_boxed_hook;
|
|
use super::Action;
|
|
use crate::error::check;
|
|
use crate::ffi;
|
|
use crate::inner_connection::InnerConnection;
|
|
use crate::types::ValueRef;
|
|
use crate::Connection;
|
|
use crate::Result;
|
|
|
|
/// The possible cases for when a PreUpdateHook gets triggered. Allows access to the relevant
|
|
/// functions for each case through the contained values.
|
|
#[derive(Debug)]
|
|
pub enum PreUpdateCase {
|
|
/// Pre-update hook was triggered by an insert.
|
|
Insert(PreUpdateNewValueAccessor),
|
|
/// Pre-update hook was triggered by a delete.
|
|
Delete(PreUpdateOldValueAccessor),
|
|
/// Pre-update hook was triggered by an update.
|
|
Update {
|
|
#[allow(missing_docs)]
|
|
old_value_accessor: PreUpdateOldValueAccessor,
|
|
#[allow(missing_docs)]
|
|
new_value_accessor: PreUpdateNewValueAccessor,
|
|
},
|
|
/// This variant is not normally produced by SQLite. You may encounter it
|
|
/// if you're using a different version than what's supported by this library.
|
|
Unknown,
|
|
}
|
|
|
|
impl From<PreUpdateCase> for Action {
|
|
fn from(puc: PreUpdateCase) -> Action {
|
|
match puc {
|
|
PreUpdateCase::Insert(_) => Action::SQLITE_INSERT,
|
|
PreUpdateCase::Delete(_) => Action::SQLITE_DELETE,
|
|
PreUpdateCase::Update { .. } => Action::SQLITE_UPDATE,
|
|
PreUpdateCase::Unknown => Action::UNKNOWN,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// An accessor to access the old values of the row being deleted/updated during the preupdate callback.
|
|
#[derive(Debug)]
|
|
pub struct PreUpdateOldValueAccessor {
|
|
db: *mut ffi::sqlite3,
|
|
old_row_id: i64,
|
|
}
|
|
|
|
impl PreUpdateOldValueAccessor {
|
|
/// Get the amount of columns in the row being deleted/updated.
|
|
pub fn get_column_count(&self) -> i32 {
|
|
unsafe { ffi::sqlite3_preupdate_count(self.db) }
|
|
}
|
|
|
|
/// Get the depth of the query that triggered the preupdate hook.
|
|
/// Returns 0 if the preupdate callback was invoked as a result of
|
|
/// a direct insert, update, or delete operation;
|
|
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
|
|
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
|
|
pub fn get_query_depth(&self) -> i32 {
|
|
unsafe { ffi::sqlite3_preupdate_depth(self.db) }
|
|
}
|
|
|
|
/// Get the row id of the row being updated/deleted.
|
|
pub fn get_old_row_id(&self) -> i64 {
|
|
self.old_row_id
|
|
}
|
|
|
|
/// Get the value of the row being updated/deleted at the specified index.
|
|
pub fn get_old_column_value(&self, i: i32) -> Result<ValueRef> {
|
|
let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut();
|
|
unsafe {
|
|
check(ffi::sqlite3_preupdate_old(self.db, i, &mut p_value))?;
|
|
Ok(ValueRef::from_value(p_value))
|
|
}
|
|
}
|
|
}
|
|
|
|
/// An accessor to access the new values of the row being inserted/updated
|
|
/// during the preupdate callback.
|
|
#[derive(Debug)]
|
|
pub struct PreUpdateNewValueAccessor {
|
|
db: *mut ffi::sqlite3,
|
|
new_row_id: i64,
|
|
}
|
|
|
|
impl PreUpdateNewValueAccessor {
|
|
/// Get the amount of columns in the row being inserted/updated.
|
|
pub fn get_column_count(&self) -> i32 {
|
|
unsafe { ffi::sqlite3_preupdate_count(self.db) }
|
|
}
|
|
|
|
/// Get the depth of the query that triggered the preupdate hook.
|
|
/// Returns 0 if the preupdate callback was invoked as a result of
|
|
/// a direct insert, update, or delete operation;
|
|
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
|
|
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
|
|
pub fn get_query_depth(&self) -> i32 {
|
|
unsafe { ffi::sqlite3_preupdate_depth(self.db) }
|
|
}
|
|
|
|
/// Get the row id of the row being inserted/updated.
|
|
pub fn get_new_row_id(&self) -> i64 {
|
|
self.new_row_id
|
|
}
|
|
|
|
/// Get the value of the row being updated/deleted at the specified index.
|
|
pub fn get_new_column_value(&self, i: i32) -> Result<ValueRef> {
|
|
let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut();
|
|
unsafe {
|
|
check(ffi::sqlite3_preupdate_new(self.db, i, &mut p_value))?;
|
|
Ok(ValueRef::from_value(p_value))
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Connection {
|
|
/// Register a callback function to be invoked before
|
|
/// a row is updated, inserted or deleted.
|
|
///
|
|
/// The callback parameters are:
|
|
///
|
|
/// - the name of the database ("main", "temp", ...),
|
|
/// - the name of the table that is updated,
|
|
/// - a variant of the PreUpdateCase enum which allows access to extra functions depending
|
|
/// on whether it's an update, delete or insert.
|
|
#[inline]
|
|
pub fn preupdate_hook<F>(&self, hook: Option<F>)
|
|
where
|
|
F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static,
|
|
{
|
|
self.db.borrow_mut().preupdate_hook(hook);
|
|
}
|
|
}
|
|
|
|
impl InnerConnection {
|
|
#[inline]
|
|
pub fn remove_preupdate_hook(&mut self) {
|
|
self.preupdate_hook(None::<fn(Action, &str, &str, &PreUpdateCase)>);
|
|
}
|
|
|
|
/// ```compile_fail
|
|
/// use rusqlite::{Connection, Result, hooks::PreUpdateCase};
|
|
/// fn main() -> Result<()> {
|
|
/// let db = Connection::open_in_memory()?;
|
|
/// {
|
|
/// let mut called = std::sync::atomic::AtomicBool::new(false);
|
|
/// db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| {
|
|
/// called.store(true, std::sync::atomic::Ordering::Relaxed);
|
|
/// }));
|
|
/// }
|
|
/// db.execute_batch("CREATE TABLE foo AS SELECT 1 AS bar;")
|
|
/// }
|
|
/// ```
|
|
fn preupdate_hook<F>(&mut self, hook: Option<F>)
|
|
where
|
|
F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static,
|
|
{
|
|
unsafe extern "C" fn call_boxed_closure<F>(
|
|
p_arg: *mut c_void,
|
|
sqlite: *mut ffi::sqlite3,
|
|
action_code: c_int,
|
|
db_name: *const c_char,
|
|
tbl_name: *const c_char,
|
|
old_row_id: i64,
|
|
new_row_id: i64,
|
|
) where
|
|
F: FnMut(Action, &str, &str, &PreUpdateCase),
|
|
{
|
|
let action = Action::from(action_code);
|
|
|
|
let preupdate_case = match action {
|
|
Action::SQLITE_INSERT => PreUpdateCase::Insert(PreUpdateNewValueAccessor {
|
|
db: sqlite,
|
|
new_row_id,
|
|
}),
|
|
Action::SQLITE_DELETE => PreUpdateCase::Delete(PreUpdateOldValueAccessor {
|
|
db: sqlite,
|
|
old_row_id,
|
|
}),
|
|
Action::SQLITE_UPDATE => PreUpdateCase::Update {
|
|
old_value_accessor: PreUpdateOldValueAccessor {
|
|
db: sqlite,
|
|
old_row_id,
|
|
},
|
|
new_value_accessor: PreUpdateNewValueAccessor {
|
|
db: sqlite,
|
|
new_row_id,
|
|
},
|
|
},
|
|
Action::UNKNOWN => PreUpdateCase::Unknown,
|
|
};
|
|
|
|
drop(catch_unwind(|| {
|
|
let boxed_hook: *mut F = p_arg.cast::<F>();
|
|
(*boxed_hook)(
|
|
action,
|
|
expect_utf8(db_name, "database name"),
|
|
expect_utf8(tbl_name, "table name"),
|
|
&preupdate_case,
|
|
);
|
|
}));
|
|
}
|
|
|
|
let free_preupdate_hook = if hook.is_some() {
|
|
Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let previous_hook = match hook {
|
|
Some(hook) => {
|
|
let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
|
|
unsafe {
|
|
ffi::sqlite3_preupdate_hook(
|
|
self.db(),
|
|
Some(call_boxed_closure::<F>),
|
|
boxed_hook.cast(),
|
|
)
|
|
}
|
|
}
|
|
_ => unsafe { ffi::sqlite3_preupdate_hook(self.db(), None, ptr::null_mut()) },
|
|
};
|
|
if !previous_hook.is_null() {
|
|
if let Some(free_boxed_hook) = self.free_preupdate_hook {
|
|
unsafe { free_boxed_hook(previous_hook) };
|
|
}
|
|
}
|
|
self.free_preupdate_hook = free_preupdate_hook;
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicBool, AtomicI32, Ordering};

    use super::super::Action;
    use super::PreUpdateCase;
    use crate::{Connection, Result};

    #[test]
    fn test_preupdate_hook_insert() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);

        db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| {
            assert_eq!(Action::SQLITE_INSERT, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            match case {
                PreUpdateCase::Insert(accessor) => {
                    assert_eq!(1, accessor.get_column_count());
                    assert_eq!(1, accessor.get_new_row_id());
                    assert_eq!(0, accessor.get_query_depth());
                    // out of bounds access should return an error
                    assert!(accessor.get_new_column_value(1).is_err());
                    assert_eq!(
                        "lisa",
                        accessor.get_new_column_value(0).unwrap().as_str().unwrap()
                    );
                }
                _ => panic!("wrong preupdate case"),
            }
            CALLED.store(true, Ordering::Relaxed);
        }));
        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_preupdate_hook_delete() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);

        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;

        db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| {
            assert_eq!(Action::SQLITE_DELETE, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            match case {
                PreUpdateCase::Delete(accessor) => {
                    assert_eq!(1, accessor.get_column_count());
                    assert_eq!(1, accessor.get_old_row_id());
                    assert_eq!(0, accessor.get_query_depth());
                    // out of bounds access should return an error
                    assert!(accessor.get_old_column_value(1).is_err());
                    assert_eq!(
                        "lisa",
                        accessor.get_old_column_value(0).unwrap().as_str().unwrap()
                    );
                }
                _ => panic!("wrong preupdate case"),
            }
            CALLED.store(true, Ordering::Relaxed);
        }));

        db.execute_batch("DELETE from foo")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_preupdate_hook_update() -> Result<()> {
        let db = Connection::open_in_memory()?;

        static CALLED: AtomicBool = AtomicBool::new(false);

        db.execute_batch("CREATE TABLE foo (t TEXT)")?;
        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;

        db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| {
            assert_eq!(Action::SQLITE_UPDATE, action);
            assert_eq!("main", db);
            assert_eq!("foo", tbl);
            match case {
                PreUpdateCase::Update {
                    old_value_accessor,
                    new_value_accessor,
                } => {
                    assert_eq!(1, old_value_accessor.get_column_count());
                    assert_eq!(1, old_value_accessor.get_old_row_id());
                    assert_eq!(0, old_value_accessor.get_query_depth());
                    // out of bounds access should return an error
                    assert!(old_value_accessor.get_old_column_value(1).is_err());
                    assert_eq!(
                        "lisa",
                        old_value_accessor
                            .get_old_column_value(0)
                            .unwrap()
                            .as_str()
                            .unwrap()
                    );

                    assert_eq!(1, new_value_accessor.get_column_count());
                    assert_eq!(1, new_value_accessor.get_new_row_id());
                    assert_eq!(0, new_value_accessor.get_query_depth());
                    // out of bounds access should return an error
                    assert!(new_value_accessor.get_new_column_value(1).is_err());
                    assert_eq!(
                        "janice",
                        new_value_accessor
                            .get_new_column_value(0)
                            .unwrap()
                            .as_str()
                            .unwrap()
                    );
                }
                _ => panic!("wrong preupdate case"),
            }
            CALLED.store(true, Ordering::Relaxed);
        }));

        db.execute_batch("UPDATE foo SET t = 'janice'")?;
        assert!(CALLED.load(Ordering::Relaxed));
        Ok(())
    }

    #[test]
    fn test_preupdate_hook_trigger_depth() -> Result<()> {
        let db = Connection::open_in_memory()?;

        // Panics inside the hook are swallowed by `catch_unwind`, so record the
        // observed depths and assert on them after the statement finishes.
        static TOP_LEVEL_DEPTH: AtomicI32 = AtomicI32::new(-1);
        static TRIGGER_DEPTH: AtomicI32 = AtomicI32::new(-1);

        db.execute_batch(
            "CREATE TABLE foo (t TEXT);
             CREATE TABLE bar (t TEXT);
             CREATE TRIGGER foo_insert AFTER INSERT ON foo
             BEGIN INSERT INTO bar VALUES (new.t); END;",
        )?;

        db.preupdate_hook(Some(|_action, _db: &str, tbl: &str, case: &PreUpdateCase| {
            if let PreUpdateCase::Insert(accessor) = case {
                let depth = accessor.get_query_depth();
                match tbl {
                    "foo" => TOP_LEVEL_DEPTH.store(depth, Ordering::Relaxed),
                    "bar" => TRIGGER_DEPTH.store(depth, Ordering::Relaxed),
                    _ => {}
                }
            }
        }));

        db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
        assert_eq!(0, TOP_LEVEL_DEPTH.load(Ordering::Relaxed));
        assert_eq!(1, TRIGGER_DEPTH.load(Ordering::Relaxed));
        Ok(())
    }
}
|