mirror of
				https://github.com/isar/rusqlite.git
				synced 2025-10-31 05:48:56 +08:00 
			
		
		
		
	Merge pull request #1594 from gwenn/wal_hook
Add safe binding to sqlite3_wal_hook
This commit is contained in:
		| @@ -7,7 +7,7 @@ use std::ptr; | ||||
|  | ||||
| use crate::ffi; | ||||
|  | ||||
| use crate::{Connection, InnerConnection}; | ||||
| use crate::{Connection, DatabaseName, InnerConnection}; | ||||
|  | ||||
| #[cfg(feature = "preupdate_hook")] | ||||
| pub use preupdate_hook::*; | ||||
| @@ -383,6 +383,39 @@ impl Connection { | ||||
|         self.db.borrow_mut().update_hook(hook); | ||||
|     } | ||||
|  | ||||
|     /// Register a callback that is invoked each time data is committed to a database in wal mode. | ||||
|     /// | ||||
|     /// A single database handle may have at most a single write-ahead log callback registered at one time. | ||||
|     /// Calling `wal_hook` replaces any previously registered write-ahead log callback. | ||||
|     /// Note that the `sqlite3_wal_autocheckpoint()` interface and the `wal_autocheckpoint` pragma | ||||
|     /// both invoke `sqlite3_wal_hook()` and will overwrite any prior `sqlite3_wal_hook()` settings. | ||||
|     pub fn wal_hook(&self, hook: Option<fn(DatabaseName<'_>, c_int) -> c_int>) { | ||||
|         unsafe extern "C" fn wal_hook_callback( | ||||
|             client_data: *mut c_void, | ||||
|             _db: *mut ffi::sqlite3, | ||||
|             db_name: *const c_char, | ||||
|             pages: c_int, | ||||
|         ) -> c_int { | ||||
|             let hook_fn: fn(DatabaseName<'_>, c_int) -> c_int = std::mem::transmute(client_data); | ||||
|             c_int::from( | ||||
|                 catch_unwind(|| { | ||||
|                     hook_fn( | ||||
|                         DatabaseName::from_cstr(std::ffi::CStr::from_ptr(db_name)), | ||||
|                         pages, | ||||
|                     ) | ||||
|                 }) | ||||
|                 .unwrap_or_default(), | ||||
|             ) | ||||
|         } | ||||
|         let c = self.db.borrow_mut(); | ||||
|         match hook { | ||||
|             Some(f) => unsafe { | ||||
|                 ffi::sqlite3_wal_hook(c.db(), Some(wal_hook_callback), f as *mut c_void) | ||||
|             }, | ||||
|             None => unsafe { ffi::sqlite3_wal_hook(c.db(), None, ptr::null_mut()) }, | ||||
|         }; | ||||
|     } | ||||
|  | ||||
|     /// Register a query progress callback. | ||||
|     /// | ||||
|     /// The parameter `num_ops` is the approximate number of virtual machine | ||||
| @@ -771,7 +804,7 @@ unsafe fn expect_optional_utf8<'a>( | ||||
| #[cfg(test)] | ||||
| mod test { | ||||
|     use super::Action; | ||||
|     use crate::{Connection, Result}; | ||||
|     use crate::{Connection, DatabaseName, Result}; | ||||
|     use std::sync::atomic::{AtomicBool, Ordering}; | ||||
|  | ||||
|     #[test] | ||||
| @@ -897,4 +930,27 @@ mod test { | ||||
|  | ||||
|         Ok(()) | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn wal_hook() -> Result<()> { | ||||
|         let temp_dir = tempfile::tempdir().unwrap(); | ||||
|         let path = temp_dir.path().join("wal-hook.db3"); | ||||
|  | ||||
|         let db = Connection::open(&path)?; | ||||
|         let journal_mode: String = | ||||
|             db.pragma_update_and_check(None, "journal_mode", "wal", |row| row.get(0))?; | ||||
|         assert_eq!(journal_mode, "wal"); | ||||
|  | ||||
|         static CALLED: AtomicBool = AtomicBool::new(false); | ||||
|         db.wal_hook(Some(|db_name, pages| { | ||||
|             assert_eq!(db_name, DatabaseName::Main); | ||||
|             assert!(pages > 0); | ||||
|             CALLED.swap(true, Ordering::Relaxed); | ||||
|             crate::ffi::SQLITE_OK | ||||
|         })); | ||||
|         db.execute_batch("CREATE TABLE x(c);")?; | ||||
|         assert!(CALLED.load(Ordering::Relaxed)); | ||||
|         db.wal_hook(None); | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
|   | ||||
							
								
								
									
										11
									
								
								src/lib.rs
									
									
									
									
									
								
							
							
						
						
									
										11
									
								
								src/lib.rs
									
									
									
									
									
								
							| @@ -343,7 +343,7 @@ fn path_to_cstring(p: &Path) -> Result<CString> { | ||||
| } | ||||
|  | ||||
| /// Name for a database within a SQLite connection. | ||||
| #[derive(Copy, Clone, Debug)] | ||||
| #[derive(Copy, Clone, Debug, PartialEq, Eq)] | ||||
| pub enum DatabaseName<'a> { | ||||
|     /// The main database. | ||||
|     Main, | ||||
| @@ -373,6 +373,15 @@ impl DatabaseName<'_> { | ||||
|             Attached(s) => str_to_cstring(s), | ||||
|         } | ||||
|     } | ||||
|     #[cfg(feature = "hooks")] | ||||
|     pub(crate) fn from_cstr(db_name: &std::ffi::CStr) -> DatabaseName<'_> { | ||||
|         let s = db_name.to_str().expect("illegal database name"); | ||||
|         match s { | ||||
|             "main" => DatabaseName::Main, | ||||
|             "temp" => DatabaseName::Temp, | ||||
|             _ => DatabaseName::Attached(s), | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// A connection to a SQLite database. | ||||
|   | ||||
		Reference in New Issue
	
	Block a user