Merge pull request #1595 from gwenn/wal_hook

Make possible to checkpoint a database from `wal_hook`
This commit is contained in:
gwenn 2024-11-10 20:35:53 +01:00 committed by GitHub
commit 8266933aa3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 89 additions and 28 deletions

View File

@ -422,6 +422,14 @@ pub unsafe fn error_from_handle(db: *mut ffi::sqlite3, code: c_int) -> Error {
error_from_sqlite_code(code, message) error_from_sqlite_code(code, message)
} }
pub unsafe fn decode_result_raw(db: *mut ffi::sqlite3, code: c_int) -> Result<()> {
if code == ffi::SQLITE_OK {
Ok(())
} else {
Err(error_from_handle(db, code))
}
}
#[cold] #[cold]
#[cfg(not(feature = "modern_sqlite"))] // SQLite >= 3.38.0 #[cfg(not(feature = "modern_sqlite"))] // SQLite >= 3.38.0
pub unsafe fn error_with_offset(db: *mut ffi::sqlite3, code: c_int, _sql: &str) -> Error { pub unsafe fn error_with_offset(db: *mut ffi::sqlite3, code: c_int, _sql: &str) -> Error {

View File

@ -7,7 +7,7 @@ use std::ptr;
use crate::ffi; use crate::ffi;
use crate::{Connection, DatabaseName, InnerConnection}; use crate::{error::decode_result_raw, Connection, DatabaseName, InnerConnection, Result};
#[cfg(feature = "preupdate_hook")] #[cfg(feature = "preupdate_hook")]
pub use preupdate_hook::*; pub use preupdate_hook::*;
@ -389,23 +389,22 @@ impl Connection {
/// Calling `wal_hook` replaces any previously registered write-ahead log callback. /// Calling `wal_hook` replaces any previously registered write-ahead log callback.
/// Note that the `sqlite3_wal_autocheckpoint()` interface and the `wal_autocheckpoint` pragma /// Note that the `sqlite3_wal_autocheckpoint()` interface and the `wal_autocheckpoint` pragma
/// both invoke `sqlite3_wal_hook()` and will overwrite any prior `sqlite3_wal_hook()` settings. /// both invoke `sqlite3_wal_hook()` and will overwrite any prior `sqlite3_wal_hook()` settings.
pub fn wal_hook(&self, hook: Option<fn(DatabaseName<'_>, c_int) -> c_int>) { pub fn wal_hook(&self, hook: Option<fn(&Wal, c_int) -> Result<()>>) {
unsafe extern "C" fn wal_hook_callback( unsafe extern "C" fn wal_hook_callback(
client_data: *mut c_void, client_data: *mut c_void,
_db: *mut ffi::sqlite3, db: *mut ffi::sqlite3,
db_name: *const c_char, db_name: *const c_char,
pages: c_int, pages: c_int,
) -> c_int { ) -> c_int {
let hook_fn: fn(DatabaseName<'_>, c_int) -> c_int = std::mem::transmute(client_data); let hook_fn: fn(&Wal, c_int) -> Result<()> = std::mem::transmute(client_data);
c_int::from( let wal = Wal { db, db_name };
catch_unwind(|| { catch_unwind(|| match hook_fn(&wal, pages) {
hook_fn( Ok(_) => ffi::SQLITE_OK,
DatabaseName::from_cstr(std::ffi::CStr::from_ptr(db_name)), Err(e) => e
pages, .sqlite_error()
) .map_or(ffi::SQLITE_ERROR, |x| x.extended_code),
}) })
.unwrap_or_default(), .unwrap_or_default()
)
} }
let c = self.db.borrow_mut(); let c = self.db.borrow_mut();
match hook { match hook {
@ -442,6 +441,57 @@ impl Connection {
} }
} }
/// Checkpoint mode
#[derive(Clone, Copy)]
#[repr(i32)]
#[non_exhaustive]
pub enum CheckpointMode {
/// Do as much as possible w/o blocking
PASSIVE = ffi::SQLITE_CHECKPOINT_PASSIVE,
/// Wait for writers, then checkpoint
FULL = ffi::SQLITE_CHECKPOINT_FULL,
/// Like FULL but wait for readers
RESTART = ffi::SQLITE_CHECKPOINT_RESTART,
/// Like RESTART but also truncate WA
TRUNCATE = ffi::SQLITE_CHECKPOINT_TRUNCATE,
}
/// Write-Ahead Log
pub struct Wal {
db: *mut ffi::sqlite3,
db_name: *const c_char,
}
impl Wal {
/// Checkpoint a database
pub fn checkpoint(&self) -> Result<()> {
unsafe { decode_result_raw(self.db, ffi::sqlite3_wal_checkpoint(self.db, self.db_name)) }
}
/// Checkpoint a database
pub fn checkpoint_v2(&self, mode: CheckpointMode) -> Result<(c_int, c_int)> {
let mut n_log = 0;
let mut n_ckpt = 0;
unsafe {
decode_result_raw(
self.db,
ffi::sqlite3_wal_checkpoint_v2(
self.db,
self.db_name,
mode as c_int,
&mut n_log,
&mut n_ckpt,
),
)?
};
Ok((n_log, n_ckpt))
}
/// Name of the database that was written to
pub fn name(&self) -> DatabaseName<'_> {
DatabaseName::from_cstr(unsafe { std::ffi::CStr::from_ptr(self.db_name) })
}
}
impl InnerConnection { impl InnerConnection {
#[inline] #[inline]
pub fn remove_hooks(&mut self) { pub fn remove_hooks(&mut self) {
@ -942,14 +992,24 @@ mod test {
assert_eq!(journal_mode, "wal"); assert_eq!(journal_mode, "wal");
static CALLED: AtomicBool = AtomicBool::new(false); static CALLED: AtomicBool = AtomicBool::new(false);
db.wal_hook(Some(|db_name, pages| { db.wal_hook(Some(|wal, pages| {
assert_eq!(db_name, DatabaseName::Main); assert_eq!(wal.name(), DatabaseName::Main);
assert!(pages > 0); assert!(pages > 0);
CALLED.swap(true, Ordering::Relaxed); CALLED.swap(true, Ordering::Relaxed);
crate::ffi::SQLITE_OK wal.checkpoint()
})); }));
db.execute_batch("CREATE TABLE x(c);")?; db.execute_batch("CREATE TABLE x(c);")?;
assert!(CALLED.load(Ordering::Relaxed)); assert!(CALLED.load(Ordering::Relaxed));
db.wal_hook(Some(|wal, pages| {
assert!(pages > 0);
let (log, ckpt) = wal.checkpoint_v2(super::CheckpointMode::TRUNCATE)?;
assert_eq!(log, 0);
assert_eq!(ckpt, 0);
Ok(())
}));
db.execute_batch("CREATE TABLE y(c);")?;
db.wal_hook(None); db.wal_hook(None);
Ok(()) Ok(())
} }

View File

@ -9,7 +9,9 @@ use std::sync::{Arc, Mutex};
use super::ffi; use super::ffi;
use super::str_for_sqlite; use super::str_for_sqlite;
use super::{Connection, InterruptHandle, OpenFlags, PrepFlags, Result}; use super::{Connection, InterruptHandle, OpenFlags, PrepFlags, Result};
use crate::error::{error_from_handle, error_from_sqlite_code, error_with_offset, Error}; use crate::error::{
decode_result_raw, error_from_handle, error_from_sqlite_code, error_with_offset, Error,
};
use crate::raw_statement::RawStatement; use crate::raw_statement::RawStatement;
use crate::statement::Statement; use crate::statement::Statement;
use crate::version_number; use crate::version_number;
@ -134,16 +136,7 @@ impl InnerConnection {
#[inline] #[inline]
pub fn decode_result(&self, code: c_int) -> Result<()> { pub fn decode_result(&self, code: c_int) -> Result<()> {
unsafe { Self::decode_result_raw(self.db(), code) } unsafe { decode_result_raw(self.db(), code) }
}
#[inline]
unsafe fn decode_result_raw(db: *mut ffi::sqlite3, code: c_int) -> Result<()> {
if code == ffi::SQLITE_OK {
Ok(())
} else {
Err(error_from_handle(db, code))
}
} }
pub fn close(&mut self) -> Result<()> { pub fn close(&mut self) -> Result<()> {
@ -165,7 +158,7 @@ impl InnerConnection {
let r = ffi::sqlite3_close(self.db); let r = ffi::sqlite3_close(self.db);
// Need to use _raw because _guard has a reference out, and // Need to use _raw because _guard has a reference out, and
// decode_result takes &mut self. // decode_result takes &mut self.
let r = Self::decode_result_raw(self.db, r); let r = decode_result_raw(self.db, r);
if r.is_ok() { if r.is_ok() {
*shared_handle = ptr::null_mut(); *shared_handle = ptr::null_mut();
self.db = ptr::null_mut(); self.db = ptr::null_mut();