Merge pull request #1486 from aschey/preupdate_hook

Add preupdate hook
This commit is contained in:
gwenn 2024-03-31 19:14:52 +02:00 committed by GitHub
commit be2689106f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 404 additions and 10 deletions

View File

@ -49,7 +49,7 @@ jobs:
# The `{ sharedKey: ... }` allows different actions to share the cache.
# We're using a `fullBuild` key mostly as a "this needs to do the
# complete" that needs to do the complete build (that is, including
# `--features 'bundled-full session buildtime_bindgen'`), which is very
# `--features 'bundled-full session buildtime_bindgen preupdate_hook'`), which is very
# slow, and has several deps.
- uses: Swatinem/rust-cache@v2
with: { sharedKey: fullBuild }
@ -62,8 +62,8 @@ jobs:
if: matrix.os == 'windows-latest'
run: echo "C:\msys64\mingw64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- run: cargo test --features 'bundled-full session buildtime_bindgen' --all-targets --workspace --verbose
- run: cargo test --features 'bundled-full session buildtime_bindgen' --doc --workspace --verbose
- run: cargo test --features 'bundled-full session buildtime_bindgen preupdate_hook' --all-targets --workspace --verbose
- run: cargo test --features 'bundled-full session buildtime_bindgen preupdate_hook' --doc --workspace --verbose
- name: loadable extension
run: |
@ -119,8 +119,8 @@ jobs:
if: matrix.os == 'windows-latest'
run: echo "C:\msys64\mingw64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- run: cargo test --features 'bundled-full session buildtime_bindgen' --all-targets --workspace --verbose
- run: cargo test --features 'bundled-full session buildtime_bindgen' --doc --workspace --verbose
- run: cargo test --features 'bundled-full session buildtime_bindgen preupdate_hook' --all-targets --workspace --verbose
- run: cargo test --features 'bundled-full session buildtime_bindgen preupdate_hook' --doc --workspace --verbose
sqlcipher:
name: Test with sqlcipher
@ -156,7 +156,7 @@ jobs:
# leak sanitization, but we don't care about backtraces here, so long
# as the other tests have them.
RUST_BACKTRACE: "0"
run: cargo -Z build-std test --features 'bundled-full session buildtime_bindgen with-asan' --target x86_64-unknown-linux-gnu
run: cargo -Z build-std test --features 'bundled-full session buildtime_bindgen preupdate_hook with-asan' --target x86_64-unknown-linux-gnu
# Ensure clippy doesn't complain.
clippy:
@ -170,7 +170,7 @@ jobs:
- uses: Swatinem/rust-cache@v2
- run: cargo clippy --all-targets --workspace --features bundled -- -D warnings
# Clippy with all non-conflicting features
- run: cargo clippy --all-targets --workspace --features 'bundled-full session buildtime_bindgen' -- -D warnings
- run: cargo clippy --all-targets --workspace --features 'bundled-full session buildtime_bindgen preupdate_hook' -- -D warnings
# Ensure patch is formatted.
fmt:
@ -192,7 +192,7 @@ jobs:
- uses: hecrj/setup-rust-action@v1
- uses: Swatinem/rust-cache@v2
with: { sharedKey: fullBuild }
- run: cargo doc --features 'bundled-full session buildtime_bindgen' --no-deps
- run: cargo doc --features 'bundled-full session buildtime_bindgen preupdate_hook' --no-deps
env: { RUSTDOCFLAGS: -Dwarnings }
codecov:
@ -210,7 +210,7 @@ jobs:
run: |
cargo test --verbose
cargo test --features="bundled-full" --verbose
cargo test --features="bundled-full session buildtime_bindgen" --verbose
cargo test --features="bundled-full session buildtime_bindgen preupdate_hook" --verbose
cargo test --features="bundled-sqlcipher-vendored-openssl" --verbose
env:
RUSTFLAGS: -Cinstrument-coverage

View File

@ -52,6 +52,7 @@ buildtime_bindgen = ["libsqlite3-sys/buildtime_bindgen"]
limits = []
loadable_extension = ["libsqlite3-sys/loadable_extension"]
hooks = []
preupdate_hook = ["libsqlite3-sys/preupdate_hook", "hooks"]
i128_blob = []
sqlcipher = ["libsqlite3-sys/sqlcipher"]
unlock_notify = ["libsqlite3-sys/unlock_notify"]

View File

@ -123,6 +123,7 @@ features](https://doc.rust-lang.org/cargo/reference/manifest.html#the-features-s
- As the name implies this depends on the `bundled-sqlcipher` feature, and automatically turns it on.
- If turned on, this uses the [`openssl-sys`](https://crates.io/crates/openssl-sys) crate, with the `vendored` feature enabled in order to build and bundle the OpenSSL crypto library.
* `hooks` for [Commit, Rollback](http://sqlite.org/c3ref/commit_hook.html) and [Data Change](http://sqlite.org/c3ref/update_hook.html) notification callbacks.
* `preupdate_hook` for [preupdate](https://sqlite.org/c3ref/preupdate_count.html) notification callbacks. (Implies `hooks`.)
* `unlock_notify` for [Unlock](https://sqlite.org/unlock_notify.html) notification.
* `vtab` for [virtual table](https://sqlite.org/vtab.html) support (allows you to write virtual table implementations in Rust). Currently, only read-only virtual tables are supported.
* `series` exposes [`generate_series(...)`](https://www.sqlite.org/series.html) Table-Valued Function. (Implies `vtab`.)

View File

@ -9,6 +9,12 @@ use crate::ffi;
use crate::{Connection, InnerConnection};
#[cfg(feature = "preupdate_hook")]
pub use preupdate_hook::*;
#[cfg(feature = "preupdate_hook")]
mod preupdate_hook;
/// Action Codes
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(i32)]

372
src/hooks/preupdate_hook.rs Normal file
View File

@ -0,0 +1,372 @@
use std::fmt::Debug;
use std::os::raw::{c_char, c_int, c_void};
use std::panic::catch_unwind;
use std::ptr;
use super::expect_utf8;
use super::free_boxed_hook;
use super::Action;
use crate::error::check;
use crate::ffi;
use crate::inner_connection::InnerConnection;
use crate::types::ValueRef;
use crate::Connection;
use crate::Result;
/// The possible cases for when a PreUpdateHook gets triggered. Allows access to the relevant
/// functions for each case through the contained values.
#[derive(Debug)]
pub enum PreUpdateCase {
/// Pre-update hook was triggered by an insert.
Insert(PreUpdateNewValueAccessor),
/// Pre-update hook was triggered by a delete.
Delete(PreUpdateOldValueAccessor),
/// Pre-update hook was triggered by an update.
Update {
#[allow(missing_docs)]
old_value_accessor: PreUpdateOldValueAccessor,
#[allow(missing_docs)]
new_value_accessor: PreUpdateNewValueAccessor,
},
/// This variant is not normally produced by SQLite. You may encounter it
/// if you're using a different version than what's supported by this library.
Unknown,
}
impl From<PreUpdateCase> for Action {
fn from(puc: PreUpdateCase) -> Action {
match puc {
PreUpdateCase::Insert(_) => Action::SQLITE_INSERT,
PreUpdateCase::Delete(_) => Action::SQLITE_DELETE,
PreUpdateCase::Update { .. } => Action::SQLITE_UPDATE,
PreUpdateCase::Unknown => Action::UNKNOWN,
}
}
}
/// An accessor to access the old values of the row being deleted/updated during the preupdate callback.
#[derive(Debug)]
pub struct PreUpdateOldValueAccessor {
db: *mut ffi::sqlite3,
old_row_id: i64,
}
impl PreUpdateOldValueAccessor {
/// Get the amount of columns in the row being deleted/updated.
pub fn get_column_count(&self) -> i32 {
unsafe { ffi::sqlite3_preupdate_count(self.db) }
}
/// Get the depth of the query that triggered the preupdate hook.
/// Returns 0 if the preupdate callback was invoked as a result of
/// a direct insert, update, or delete operation;
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
pub fn get_query_depth(&self) -> i32 {
unsafe { ffi::sqlite3_preupdate_depth(self.db) }
}
/// Get the row id of the row being updated/deleted.
pub fn get_old_row_id(&self) -> i64 {
self.old_row_id
}
/// Get the value of the row being updated/deleted at the specified index.
pub fn get_old_column_value(&self, i: i32) -> Result<ValueRef> {
let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut();
unsafe {
check(ffi::sqlite3_preupdate_old(self.db, i, &mut p_value))?;
Ok(ValueRef::from_value(p_value))
}
}
}
/// An accessor to access the new values of the row being inserted/updated
/// during the preupdate callback.
#[derive(Debug)]
pub struct PreUpdateNewValueAccessor {
db: *mut ffi::sqlite3,
new_row_id: i64,
}
impl PreUpdateNewValueAccessor {
/// Get the amount of columns in the row being inserted/updated.
pub fn get_column_count(&self) -> i32 {
unsafe { ffi::sqlite3_preupdate_count(self.db) }
}
/// Get the depth of the query that triggered the preupdate hook.
/// Returns 0 if the preupdate callback was invoked as a result of
/// a direct insert, update, or delete operation;
/// 1 for inserts, updates, or deletes invoked by top-level triggers;
/// 2 for changes resulting from triggers called by top-level triggers; and so forth.
pub fn get_query_depth(&self) -> i32 {
unsafe { ffi::sqlite3_preupdate_depth(self.db) }
}
/// Get the row id of the row being inserted/updated.
pub fn get_new_row_id(&self) -> i64 {
self.new_row_id
}
/// Get the value of the row being updated/deleted at the specified index.
pub fn get_new_column_value(&self, i: i32) -> Result<ValueRef> {
let mut p_value: *mut ffi::sqlite3_value = ptr::null_mut();
unsafe {
check(ffi::sqlite3_preupdate_new(self.db, i, &mut p_value))?;
Ok(ValueRef::from_value(p_value))
}
}
}
impl Connection {
/// Register a callback function to be invoked before
/// a row is updated, inserted or deleted.
///
/// The callback parameters are:
///
/// - the name of the database ("main", "temp", ...),
/// - the name of the table that is updated,
/// - a variant of the PreUpdateCase enum which allows access to extra functions depending
/// on whether it's an update, delete or insert.
#[inline]
pub fn preupdate_hook<F>(&self, hook: Option<F>)
where
F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static,
{
self.db.borrow_mut().preupdate_hook(hook);
}
}
impl InnerConnection {
#[inline]
pub fn remove_preupdate_hook(&mut self) {
self.preupdate_hook(None::<fn(Action, &str, &str, &PreUpdateCase)>);
}
/// ```compile_fail
/// use rusqlite::{Connection, Result, hooks::PreUpdateCase};
/// fn main() -> Result<()> {
/// let db = Connection::open_in_memory()?;
/// {
/// let mut called = std::sync::atomic::AtomicBool::new(false);
/// db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| {
/// called.store(true, std::sync::atomic::Ordering::Relaxed);
/// }));
/// }
/// db.execute_batch("CREATE TABLE foo AS SELECT 1 AS bar;")
/// }
/// ```
fn preupdate_hook<F>(&mut self, hook: Option<F>)
where
F: FnMut(Action, &str, &str, &PreUpdateCase) + Send + 'static,
{
unsafe extern "C" fn call_boxed_closure<F>(
p_arg: *mut c_void,
sqlite: *mut ffi::sqlite3,
action_code: c_int,
db_name: *const c_char,
tbl_name: *const c_char,
old_row_id: i64,
new_row_id: i64,
) where
F: FnMut(Action, &str, &str, &PreUpdateCase),
{
let action = Action::from(action_code);
let preupdate_case = match action {
Action::SQLITE_INSERT => PreUpdateCase::Insert(PreUpdateNewValueAccessor {
db: sqlite,
new_row_id,
}),
Action::SQLITE_DELETE => PreUpdateCase::Delete(PreUpdateOldValueAccessor {
db: sqlite,
old_row_id,
}),
Action::SQLITE_UPDATE => PreUpdateCase::Update {
old_value_accessor: PreUpdateOldValueAccessor {
db: sqlite,
old_row_id,
},
new_value_accessor: PreUpdateNewValueAccessor {
db: sqlite,
new_row_id,
},
},
Action::UNKNOWN => PreUpdateCase::Unknown,
};
drop(catch_unwind(|| {
let boxed_hook: *mut F = p_arg.cast::<F>();
(*boxed_hook)(
action,
expect_utf8(db_name, "database name"),
expect_utf8(tbl_name, "table name"),
&preupdate_case,
);
}));
}
let free_preupdate_hook = if hook.is_some() {
Some(free_boxed_hook::<F> as unsafe fn(*mut c_void))
} else {
None
};
let previous_hook = match hook {
Some(hook) => {
let boxed_hook: *mut F = Box::into_raw(Box::new(hook));
unsafe {
ffi::sqlite3_preupdate_hook(
self.db(),
Some(call_boxed_closure::<F>),
boxed_hook.cast(),
)
}
}
_ => unsafe { ffi::sqlite3_preupdate_hook(self.db(), None, ptr::null_mut()) },
};
if !previous_hook.is_null() {
if let Some(free_boxed_hook) = self.free_preupdate_hook {
unsafe { free_boxed_hook(previous_hook) };
}
}
self.free_preupdate_hook = free_preupdate_hook;
}
}
#[cfg(test)]
mod test {
use std::sync::atomic::{AtomicBool, Ordering};
use super::super::Action;
use super::PreUpdateCase;
use crate::{Connection, Result};
#[test]
fn test_preupdate_hook_insert() -> Result<()> {
let db = Connection::open_in_memory()?;
static CALLED: AtomicBool = AtomicBool::new(false);
db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| {
assert_eq!(Action::SQLITE_INSERT, action);
assert_eq!("main", db);
assert_eq!("foo", tbl);
match case {
PreUpdateCase::Insert(accessor) => {
assert_eq!(1, accessor.get_column_count());
assert_eq!(1, accessor.get_new_row_id());
assert_eq!(0, accessor.get_query_depth());
// out of bounds access should return an error
assert!(accessor.get_new_column_value(1).is_err());
assert_eq!(
"lisa",
accessor.get_new_column_value(0).unwrap().as_str().unwrap()
);
assert_eq!(0, accessor.get_query_depth());
}
_ => panic!("wrong preupdate case"),
}
CALLED.store(true, Ordering::Relaxed);
}));
db.execute_batch("CREATE TABLE foo (t TEXT)")?;
db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
assert!(CALLED.load(Ordering::Relaxed));
Ok(())
}
#[test]
fn test_preupdate_hook_delete() -> Result<()> {
let db = Connection::open_in_memory()?;
static CALLED: AtomicBool = AtomicBool::new(false);
db.execute_batch("CREATE TABLE foo (t TEXT)")?;
db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| {
assert_eq!(Action::SQLITE_DELETE, action);
assert_eq!("main", db);
assert_eq!("foo", tbl);
match case {
PreUpdateCase::Delete(accessor) => {
assert_eq!(1, accessor.get_column_count());
assert_eq!(1, accessor.get_old_row_id());
assert_eq!(0, accessor.get_query_depth());
// out of bounds access should return an error
assert!(accessor.get_old_column_value(1).is_err());
assert_eq!(
"lisa",
accessor.get_old_column_value(0).unwrap().as_str().unwrap()
);
assert_eq!(0, accessor.get_query_depth());
}
_ => panic!("wrong preupdate case"),
}
CALLED.store(true, Ordering::Relaxed);
}));
db.execute_batch("DELETE from foo")?;
assert!(CALLED.load(Ordering::Relaxed));
Ok(())
}
#[test]
fn test_preupdate_hook_update() -> Result<()> {
let db = Connection::open_in_memory()?;
static CALLED: AtomicBool = AtomicBool::new(false);
db.execute_batch("CREATE TABLE foo (t TEXT)")?;
db.execute_batch("INSERT INTO foo VALUES ('lisa')")?;
db.preupdate_hook(Some(|action, db: &str, tbl: &str, case: &PreUpdateCase| {
assert_eq!(Action::SQLITE_UPDATE, action);
assert_eq!("main", db);
assert_eq!("foo", tbl);
match case {
PreUpdateCase::Update {
old_value_accessor,
new_value_accessor,
} => {
assert_eq!(1, old_value_accessor.get_column_count());
assert_eq!(1, old_value_accessor.get_old_row_id());
assert_eq!(0, old_value_accessor.get_query_depth());
// out of bounds access should return an error
assert!(old_value_accessor.get_old_column_value(1).is_err());
assert_eq!(
"lisa",
old_value_accessor
.get_old_column_value(0)
.unwrap()
.as_str()
.unwrap()
);
assert_eq!(0, old_value_accessor.get_query_depth());
assert_eq!(1, new_value_accessor.get_column_count());
assert_eq!(1, new_value_accessor.get_new_row_id());
assert_eq!(0, new_value_accessor.get_query_depth());
// out of bounds access should return an error
assert!(new_value_accessor.get_new_column_value(1).is_err());
assert_eq!(
"janice",
new_value_accessor
.get_new_column_value(0)
.unwrap()
.as_str()
.unwrap()
);
assert_eq!(0, new_value_accessor.get_query_depth());
}
_ => panic!("wrong preupdate case"),
}
CALLED.store(true, Ordering::Relaxed);
}));
db.execute_batch("UPDATE foo SET t = 'janice'")?;
assert!(CALLED.load(Ordering::Relaxed));
Ok(())
}
}

View File

@ -33,6 +33,8 @@ pub struct InnerConnection {
pub progress_handler: Option<Box<dyn FnMut() -> bool + Send>>,
#[cfg(feature = "hooks")]
pub authorizer: Option<crate::hooks::BoxedAuthorizer>,
#[cfg(feature = "preupdate_hook")]
pub free_preupdate_hook: Option<unsafe fn(*mut ::std::os::raw::c_void)>,
owned: bool,
}
@ -55,6 +57,8 @@ impl InnerConnection {
progress_handler: None,
#[cfg(feature = "hooks")]
authorizer: None,
#[cfg(feature = "preupdate_hook")]
free_preupdate_hook: None,
owned,
}
}
@ -148,6 +152,7 @@ impl InnerConnection {
return Ok(());
}
self.remove_hooks();
self.remove_preupdate_hook();
let mut shared_handle = self.interrupt_lock.lock().unwrap();
assert!(
!shared_handle.is_null(),
@ -337,6 +342,10 @@ impl InnerConnection {
#[inline]
fn remove_hooks(&mut self) {}
#[cfg(not(feature = "preupdate_hook"))]
#[inline]
fn remove_preupdate_hook(&mut self) {}
pub fn db_readonly(&self, db_name: super::DatabaseName<'_>) -> Result<bool> {
let name = db_name.as_cstring()?;
let r = unsafe { ffi::sqlite3_db_readonly(self.db, name.as_ptr()) };

View File

@ -207,7 +207,12 @@ where
}
}
#[cfg(any(feature = "functions", feature = "session", feature = "vtab"))]
#[cfg(any(
feature = "functions",
feature = "session",
feature = "vtab",
feature = "preupdate_hook"
))]
impl<'a> ValueRef<'a> {
pub(crate) unsafe fn from_value(value: *mut crate::ffi::sqlite3_value) -> ValueRef<'a> {
use crate::ffi;