From 506ecb91d039e34ff9a9e3672fda70c6d018b05a Mon Sep 17 00:00:00 2001 From: gwenn Date: Sun, 10 Nov 2024 11:54:57 +0100 Subject: [PATCH] Add safe binding to sqlite3_wal_hook --- src/hooks/mod.rs | 60 ++++++++++++++++++++++++++++++++++++++++++++++-- src/lib.rs | 10 +++++++- 2 files changed, 67 insertions(+), 3 deletions(-) diff --git a/src/hooks/mod.rs b/src/hooks/mod.rs index be49bc1..9d14aad 100644 --- a/src/hooks/mod.rs +++ b/src/hooks/mod.rs @@ -7,7 +7,7 @@ use std::ptr; use crate::ffi; -use crate::{Connection, InnerConnection}; +use crate::{Connection, DatabaseName, InnerConnection}; #[cfg(feature = "preupdate_hook")] pub use preupdate_hook::*; @@ -383,6 +383,39 @@ impl Connection { self.db.borrow_mut().update_hook(hook); } + /// Register a callback that is invoked each time data is committed to a database in wal mode. + /// + /// A single database handle may have at most a single write-ahead log callback registered at one time. + /// Calling `wal_hook` replaces any previously registered write-ahead log callback. + /// Note that the `sqlite3_wal_autocheckpoint()` interface and the `wal_autocheckpoint` pragma + /// both invoke `sqlite3_wal_hook()` and will overwrite any prior `sqlite3_wal_hook()` settings. + pub fn wal_hook(&self, hook: Option, c_int) -> c_int>) { + unsafe extern "C" fn wal_hook_callback( + client_data: *mut c_void, + _db: *mut ffi::sqlite3, + db_name: *const c_char, + pages: c_int, + ) -> c_int { + let hook_fn: fn(DatabaseName<'_>, c_int) -> c_int = std::mem::transmute(client_data); + c_int::from( + catch_unwind(|| { + hook_fn( + DatabaseName::from_cstr(std::ffi::CStr::from_ptr(db_name)), + pages, + ) + }) + .unwrap_or_default(), + ) + } + let c = self.db.borrow_mut(); + match hook { + Some(f) => unsafe { + ffi::sqlite3_wal_hook(c.db(), Some(wal_hook_callback), f as *mut c_void) + }, + None => unsafe { ffi::sqlite3_wal_hook(c.db(), None, ptr::null_mut()) }, + }; + } + /// Register a query progress callback. /// /// The parameter `num_ops` is the approximate number of virtual machine @@ -771,7 +804,7 @@ unsafe fn expect_optional_utf8<'a>( #[cfg(test)] mod test { use super::Action; - use crate::{Connection, Result}; + use crate::{Connection, DatabaseName, Result}; use std::sync::atomic::{AtomicBool, Ordering}; #[test] @@ -897,4 +930,27 @@ mod test { Ok(()) } + + #[test] + fn wal_hook() -> Result<()> { + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("wal-hook.db3"); + + let db = Connection::open(&path)?; + let journal_mode: String = + db.pragma_update_and_check(None, "journal_mode", "wal", |row| row.get(0))?; + assert_eq!(journal_mode, "wal"); + + static CALLED: AtomicBool = AtomicBool::new(false); + db.wal_hook(Some(|db_name, pages| { + assert_eq!(db_name, DatabaseName::Main); + assert!(pages > 0); + CALLED.swap(true, Ordering::Relaxed); + crate::ffi::SQLITE_OK + })); + db.execute_batch("CREATE TABLE x(c);")?; + assert!(CALLED.load(Ordering::Relaxed)); + db.wal_hook(None); + Ok(()) + } } diff --git a/src/lib.rs b/src/lib.rs index f47425c..48ce603 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -343,7 +343,7 @@ fn path_to_cstring(p: &Path) -> Result { } /// Name for a database within a SQLite connection. -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum DatabaseName<'a> { /// The main database. Main, @@ -373,6 +373,14 @@ impl DatabaseName<'_> { Attached(s) => str_to_cstring(s), } } + pub(crate) fn from_cstr(db_name: &std::ffi::CStr) -> DatabaseName<'_> { + let s = db_name.to_str().expect("illegal database name"); + match s { + "main" => DatabaseName::Main, + "temp" => DatabaseName::Temp, + _ => DatabaseName::Attached(s), + } + } } /// A connection to a SQLite database.