diff --git a/src/error.rs b/src/error.rs index a256797..bb46e2d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -422,6 +422,14 @@ pub unsafe fn error_from_handle(db: *mut ffi::sqlite3, code: c_int) -> Error { error_from_sqlite_code(code, message) } +pub unsafe fn decode_result_raw(db: *mut ffi::sqlite3, code: c_int) -> Result<()> { + if code == ffi::SQLITE_OK { + Ok(()) + } else { + Err(error_from_handle(db, code)) + } +} + #[cold] #[cfg(not(feature = "modern_sqlite"))] // SQLite >= 3.38.0 pub unsafe fn error_with_offset(db: *mut ffi::sqlite3, code: c_int, _sql: &str) -> Error { diff --git a/src/hooks/mod.rs b/src/hooks/mod.rs index 9d14aad..8e5f951 100644 --- a/src/hooks/mod.rs +++ b/src/hooks/mod.rs @@ -7,7 +7,7 @@ use std::ptr; use crate::ffi; -use crate::{Connection, DatabaseName, InnerConnection}; +use crate::{error::decode_result_raw, Connection, DatabaseName, InnerConnection, Result}; #[cfg(feature = "preupdate_hook")] pub use preupdate_hook::*; @@ -389,23 +389,22 @@ impl Connection { /// Calling `wal_hook` replaces any previously registered write-ahead log callback. /// Note that the `sqlite3_wal_autocheckpoint()` interface and the `wal_autocheckpoint` pragma /// both invoke `sqlite3_wal_hook()` and will overwrite any prior `sqlite3_wal_hook()` settings. - pub fn wal_hook(&self, hook: Option, c_int) -> c_int>) { + pub fn wal_hook(&self, hook: Option Result<()>>) { unsafe extern "C" fn wal_hook_callback( client_data: *mut c_void, - _db: *mut ffi::sqlite3, + db: *mut ffi::sqlite3, db_name: *const c_char, pages: c_int, ) -> c_int { - let hook_fn: fn(DatabaseName<'_>, c_int) -> c_int = std::mem::transmute(client_data); - c_int::from( - catch_unwind(|| { - hook_fn( - DatabaseName::from_cstr(std::ffi::CStr::from_ptr(db_name)), - pages, - ) - }) - .unwrap_or_default(), - ) + let hook_fn: fn(&Wal, c_int) -> Result<()> = std::mem::transmute(client_data); + let wal = Wal { db, db_name }; + catch_unwind(|| match hook_fn(&wal, pages) { + Ok(_) => ffi::SQLITE_OK, + Err(e) => e + .sqlite_error() + .map_or(ffi::SQLITE_ERROR, |x| x.extended_code), + }) + .unwrap_or_default() } let c = self.db.borrow_mut(); match hook { @@ -442,6 +441,57 @@ impl Connection { } } +/// Checkpoint mode +#[derive(Clone, Copy)] +#[repr(i32)] +#[non_exhaustive] +pub enum CheckpointMode { + /// Do as much as possible w/o blocking + PASSIVE = ffi::SQLITE_CHECKPOINT_PASSIVE, + /// Wait for writers, then checkpoint + FULL = ffi::SQLITE_CHECKPOINT_FULL, + /// Like FULL but wait for readers + RESTART = ffi::SQLITE_CHECKPOINT_RESTART, + /// Like RESTART but also truncate WA + TRUNCATE = ffi::SQLITE_CHECKPOINT_TRUNCATE, +} + +/// Write-Ahead Log +pub struct Wal { + db: *mut ffi::sqlite3, + db_name: *const c_char, +} + +impl Wal { + /// Checkpoint a database + pub fn checkpoint(&self) -> Result<()> { + unsafe { decode_result_raw(self.db, ffi::sqlite3_wal_checkpoint(self.db, self.db_name)) } + } + /// Checkpoint a database + pub fn checkpoint_v2(&self, mode: CheckpointMode) -> Result<(c_int, c_int)> { + let mut n_log = 0; + let mut n_ckpt = 0; + unsafe { + decode_result_raw( + self.db, + ffi::sqlite3_wal_checkpoint_v2( + self.db, + self.db_name, + mode as c_int, + &mut n_log, + &mut n_ckpt, + ), + )? + }; + Ok((n_log, n_ckpt)) + } + + /// Name of the database that was written to + pub fn name(&self) -> DatabaseName<'_> { + DatabaseName::from_cstr(unsafe { std::ffi::CStr::from_ptr(self.db_name) }) + } +} + impl InnerConnection { #[inline] pub fn remove_hooks(&mut self) { @@ -942,14 +992,24 @@ mod test { assert_eq!(journal_mode, "wal"); static CALLED: AtomicBool = AtomicBool::new(false); - db.wal_hook(Some(|db_name, pages| { - assert_eq!(db_name, DatabaseName::Main); + db.wal_hook(Some(|wal, pages| { + assert_eq!(wal.name(), DatabaseName::Main); assert!(pages > 0); CALLED.swap(true, Ordering::Relaxed); - crate::ffi::SQLITE_OK + wal.checkpoint() })); db.execute_batch("CREATE TABLE x(c);")?; assert!(CALLED.load(Ordering::Relaxed)); + + db.wal_hook(Some(|wal, pages| { + assert!(pages > 0); + let (log, ckpt) = wal.checkpoint_v2(super::CheckpointMode::TRUNCATE)?; + assert_eq!(log, 0); + assert_eq!(ckpt, 0); + Ok(()) + })); + db.execute_batch("CREATE TABLE y(c);")?; + db.wal_hook(None); Ok(()) } diff --git a/src/inner_connection.rs b/src/inner_connection.rs index 1c08899..bc7f72f 100644 --- a/src/inner_connection.rs +++ b/src/inner_connection.rs @@ -9,7 +9,9 @@ use std::sync::{Arc, Mutex}; use super::ffi; use super::str_for_sqlite; use super::{Connection, InterruptHandle, OpenFlags, PrepFlags, Result}; -use crate::error::{error_from_handle, error_from_sqlite_code, error_with_offset, Error}; +use crate::error::{ + decode_result_raw, error_from_handle, error_from_sqlite_code, error_with_offset, Error, +}; use crate::raw_statement::RawStatement; use crate::statement::Statement; use crate::version_number; @@ -134,16 +136,7 @@ impl InnerConnection { #[inline] pub fn decode_result(&self, code: c_int) -> Result<()> { - unsafe { Self::decode_result_raw(self.db(), code) } - } - - #[inline] - unsafe fn decode_result_raw(db: *mut ffi::sqlite3, code: c_int) -> Result<()> { - if code == ffi::SQLITE_OK { - Ok(()) - } else { - Err(error_from_handle(db, code)) - } + unsafe { decode_result_raw(self.db(), code) } } pub fn close(&mut self) -> Result<()> { @@ -165,7 +158,7 @@ impl InnerConnection { let r = ffi::sqlite3_close(self.db); // Need to use _raw because _guard has a reference out, and // decode_result takes &mut self. - let r = Self::decode_result_raw(self.db, r); + let r = decode_result_raw(self.db, r); if r.is_ok() { *shared_handle = ptr::null_mut(); self.db = ptr::null_mut();