Merge pull request #1594 from gwenn/wal_hook

Add safe binding to sqlite3_wal_hook
This commit is contained in:
gwenn 2024-11-10 12:28:19 +01:00 committed by GitHub
commit 0af8bfc603
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 68 additions and 3 deletions

View File

@ -7,7 +7,7 @@ use std::ptr;
use crate::ffi;
use crate::{Connection, InnerConnection};
use crate::{Connection, DatabaseName, InnerConnection};
#[cfg(feature = "preupdate_hook")]
pub use preupdate_hook::*;
@ -383,6 +383,39 @@ impl Connection {
self.db.borrow_mut().update_hook(hook);
}
/// Register a callback that is invoked each time data is committed to a database in wal mode.
///
/// A single database handle may have at most a single write-ahead log callback registered at one time.
/// Calling `wal_hook` replaces any previously registered write-ahead log callback.
/// Note that the `sqlite3_wal_autocheckpoint()` interface and the `wal_autocheckpoint` pragma
/// both invoke `sqlite3_wal_hook()` and will overwrite any prior `sqlite3_wal_hook()` settings.
pub fn wal_hook(&self, hook: Option<fn(DatabaseName<'_>, c_int) -> c_int>) {
unsafe extern "C" fn wal_hook_callback(
client_data: *mut c_void,
_db: *mut ffi::sqlite3,
db_name: *const c_char,
pages: c_int,
) -> c_int {
let hook_fn: fn(DatabaseName<'_>, c_int) -> c_int = std::mem::transmute(client_data);
c_int::from(
catch_unwind(|| {
hook_fn(
DatabaseName::from_cstr(std::ffi::CStr::from_ptr(db_name)),
pages,
)
})
.unwrap_or_default(),
)
}
let c = self.db.borrow_mut();
match hook {
Some(f) => unsafe {
ffi::sqlite3_wal_hook(c.db(), Some(wal_hook_callback), f as *mut c_void)
},
None => unsafe { ffi::sqlite3_wal_hook(c.db(), None, ptr::null_mut()) },
};
}
/// Register a query progress callback.
///
/// The parameter `num_ops` is the approximate number of virtual machine
@ -771,7 +804,7 @@ unsafe fn expect_optional_utf8<'a>(
#[cfg(test)]
mod test {
use super::Action;
use crate::{Connection, Result};
use crate::{Connection, DatabaseName, Result};
use std::sync::atomic::{AtomicBool, Ordering};
#[test]
@ -897,4 +930,27 @@ mod test {
Ok(())
}
#[test]
fn wal_hook() -> Result<()> {
let temp_dir = tempfile::tempdir().unwrap();
let path = temp_dir.path().join("wal-hook.db3");
let db = Connection::open(&path)?;
let journal_mode: String =
db.pragma_update_and_check(None, "journal_mode", "wal", |row| row.get(0))?;
assert_eq!(journal_mode, "wal");
static CALLED: AtomicBool = AtomicBool::new(false);
db.wal_hook(Some(|db_name, pages| {
assert_eq!(db_name, DatabaseName::Main);
assert!(pages > 0);
CALLED.swap(true, Ordering::Relaxed);
crate::ffi::SQLITE_OK
}));
db.execute_batch("CREATE TABLE x(c);")?;
assert!(CALLED.load(Ordering::Relaxed));
db.wal_hook(None);
Ok(())
}
}

View File

@ -343,7 +343,7 @@ fn path_to_cstring(p: &Path) -> Result<CString> {
}
/// Name for a database within a SQLite connection.
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DatabaseName<'a> {
/// The main database.
Main,
@ -373,6 +373,15 @@ impl DatabaseName<'_> {
Attached(s) => str_to_cstring(s),
}
}
#[cfg(feature = "hooks")]
pub(crate) fn from_cstr(db_name: &std::ffi::CStr) -> DatabaseName<'_> {
let s = db_name.to_str().expect("illegal database name");
match s {
"main" => DatabaseName::Main,
"temp" => DatabaseName::Temp,
_ => DatabaseName::Attached(s),
}
}
}
/// A connection to a SQLite database.