diff --git a/.travis.yml b/.travis.yml index de9a910..d3aa772 100644 --- a/.travis.yml +++ b/.travis.yml @@ -38,6 +38,7 @@ script: - cargo test --features serde_json - cargo test --features bundled - cargo test --features sqlcipher + - cargo test --features "unlock_notify bundled" - cargo test --features "backup blob chrono functions hooks limits load_extension serde_json trace" - cargo test --features "backup blob chrono functions hooks limits load_extension serde_json trace buildtime_bindgen" - cargo test --features "backup blob chrono functions hooks limits load_extension serde_json trace bundled" diff --git a/Cargo.toml b/Cargo.toml index fdf88bc..1befcc6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,7 @@ buildtime_bindgen = ["libsqlite3-sys/buildtime_bindgen"] limits = [] hooks = [] sqlcipher = ["libsqlite3-sys/sqlcipher"] +unlock_notify = ["libsqlite3-sys/unlock_notify"] [dependencies] time = "0.1.0" diff --git a/libsqlite3-sys/Cargo.toml b/libsqlite3-sys/Cargo.toml index 64e3f1e..0378f5b 100644 --- a/libsqlite3-sys/Cargo.toml +++ b/libsqlite3-sys/Cargo.toml @@ -21,6 +21,8 @@ min_sqlite_version_3_6_23 = ["pkg-config", "vcpkg"] min_sqlite_version_3_7_3 = ["pkg-config", "vcpkg"] min_sqlite_version_3_7_4 = ["pkg-config", "vcpkg"] min_sqlite_version_3_7_16 = ["pkg-config", "vcpkg"] +# sqlite3_unlock_notify >= 3.6.12 +unlock_notify = [] [build-dependencies] bindgen = { version = "0.36", optional = true } diff --git a/libsqlite3-sys/build.rs b/libsqlite3-sys/build.rs index 0757d9f..955f565 100644 --- a/libsqlite3-sys/build.rs +++ b/libsqlite3-sys/build.rs @@ -18,8 +18,8 @@ mod build { fs::copy("sqlite3/bindgen_bundled_version.rs", out_path) .expect("Could not copy bindings to output directory"); - cc::Build::new() - .file("sqlite3/sqlite3.c") + let mut cfg = cc::Build::new(); + cfg.file("sqlite3/sqlite3.c") .flag("-DSQLITE_CORE") .flag("-DSQLITE_DEFAULT_FOREIGN_KEYS=1") .flag("-DSQLITE_ENABLE_API_ARMOR") @@ -38,8 +38,11 @@ mod build { .flag("-DSQLITE_SOUNDEX") .flag("-DSQLITE_THREADSAFE=1") .flag("-DSQLITE_USE_URI") - .flag("-DHAVE_USLEEP=1") - .compile("libsqlite3.a"); + .flag("-DHAVE_USLEEP=1"); + if cfg!(feature = "unlock_notify") { + cfg.flag("-DSQLITE_ENABLE_UNLOCK_NOTIFY"); + } + cfg.compile("libsqlite3.a"); } } diff --git a/src/lib.rs b/src/lib.rs index 98e62da..d220e12 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -125,6 +125,7 @@ pub mod limits; mod hooks; #[cfg(feature = "hooks")] pub use hooks::*; +mod unlock_notify; // Number of cached prepared statements we'll hold on to. const STATEMENT_CACHE_DEFAULT_CAPACITY: usize = 16; @@ -862,16 +863,40 @@ impl InnerConnection { } let mut c_stmt: *mut ffi::sqlite3_stmt = unsafe { mem::uninitialized() }; let c_sql = try!(str_to_cstring(sql)); + let len_with_nul = (sql.len() + 1) as c_int; let r = unsafe { - let len_with_nul = (sql.len() + 1) as c_int; - ffi::sqlite3_prepare_v2(self.db(), - c_sql.as_ptr(), - len_with_nul, - &mut c_stmt, - ptr::null_mut()) + if cfg!(feature = "unlock_notify") { + let mut rc; + loop { + rc = ffi::sqlite3_prepare_v2( + self.db(), + c_sql.as_ptr(), + len_with_nul, + &mut c_stmt, + ptr::null_mut(), + ); + if !unlock_notify::is_locked(self.db, rc) { + break; + } + rc = unlock_notify::wait_for_unlock_notify(self.db); + if rc != ffi::SQLITE_OK { + break; + } + } + rc + } else { + ffi::sqlite3_prepare_v2( + self.db(), + c_sql.as_ptr(), + len_with_nul, + &mut c_stmt, + ptr::null_mut(), + ) + } }; - self.decode_result(r) - .map(|_| Statement::new(conn, RawStatement::new(c_stmt))) + self.decode_result(r).map(|_| { + Statement::new(conn, RawStatement::new(c_stmt)) + }) } fn changes(&mut self) -> c_int { diff --git a/src/raw_statement.rs b/src/raw_statement.rs index d25c09f..3c33ea3 100644 --- a/src/raw_statement.rs +++ b/src/raw_statement.rs @@ -2,6 +2,7 @@ use std::ffi::CStr; use std::ptr; use std::os::raw::c_int; use super::ffi; +use super::unlock_notify; // Private newtype for raw sqlite3_stmts that finalize themselves when dropped. #[derive(Debug)] @@ -29,7 +30,24 @@ impl RawStatement { } pub fn step(&self) -> c_int { - unsafe { ffi::sqlite3_step(self.0) } + if cfg!(feature = "unlock_notify") { + let db = unsafe { ffi::sqlite3_db_handle(self.0) }; + let mut rc; + loop { + rc = unsafe { ffi::sqlite3_step(self.0) }; + if !unlock_notify::is_locked(db, rc) { + break; + } + rc = unlock_notify::wait_for_unlock_notify(db); + if rc != ffi::SQLITE_OK { + break; + } + self.reset(); + } + rc + } else { + unsafe { ffi::sqlite3_step(self.0) } + } } pub fn reset(&self) -> c_int { diff --git a/src/unlock_notify.rs b/src/unlock_notify.rs new file mode 100644 index 0000000..b11c432 --- /dev/null +++ b/src/unlock_notify.rs @@ -0,0 +1,129 @@ +//! [Unlock Notification](http://sqlite.org/unlock_notify.html) + +#[cfg(feature = "unlock_notify")] +use std::sync::{Condvar, Mutex}; +use std::os::raw::c_int; +#[cfg(feature = "unlock_notify")] +use std::os::raw::c_void; + +use ffi; + +#[cfg(feature = "unlock_notify")] +struct UnlockNotification { + cond: Condvar, // Condition variable to wait on + mutex: Mutex, // Mutex to protect structure +} + +#[cfg(feature = "unlock_notify")] +impl UnlockNotification { + fn new() -> UnlockNotification { + UnlockNotification { + cond: Condvar::new(), + mutex: Mutex::new(false), + } + } + + fn fired(&mut self) { + *self.mutex.lock().unwrap() = true; + self.cond.notify_one(); + } + + fn wait(&mut self) { + let mut fired = self.mutex.lock().unwrap(); + while !*fired { + fired = self.cond.wait(fired).unwrap(); + } + } +} + +/// This function is an unlock-notify callback +#[cfg(feature = "unlock_notify")] +unsafe extern "C" fn unlock_notify_cb(ap_arg: *mut *mut c_void, n_arg: c_int) { + use std::slice::from_raw_parts; + let args = from_raw_parts(ap_arg, n_arg as usize); + for arg in args { + let un: &mut UnlockNotification = &mut *(*arg as *mut UnlockNotification); + un.fired(); + } +} + +#[cfg(feature = "unlock_notify")] +pub fn is_locked(db: *mut ffi::sqlite3, rc: c_int) -> bool { + rc == ffi::SQLITE_LOCKED_SHAREDCACHE || (rc & 0xFF) == ffi::SQLITE_LOCKED && unsafe { + ffi::sqlite3_extended_errcode(db) + } + == ffi::SQLITE_LOCKED_SHAREDCACHE +} + +/// This function assumes that an SQLite API call (either `sqlite3_prepare_v2()` +/// or `sqlite3_step()`) has just returned `SQLITE_LOCKED`. The argument is the +/// associated database connection. +/// +/// This function calls `sqlite3_unlock_notify()` to register for an +/// unlock-notify callback, then blocks until that callback is delivered +/// and returns `SQLITE_OK`. The caller should then retry the failed operation. +/// +/// Or, if `sqlite3_unlock_notify()` indicates that to block would deadlock +/// the system, then this function returns `SQLITE_LOCKED` immediately. In +/// this case the caller should not retry the operation and should roll +/// back the current transaction (if any). +#[cfg(feature = "unlock_notify")] +pub fn wait_for_unlock_notify(db: *mut ffi::sqlite3) -> c_int { + let mut un = UnlockNotification::new(); + /* Register for an unlock-notify callback. */ + let rc = unsafe { + ffi::sqlite3_unlock_notify( + db, + Some(unlock_notify_cb), + &mut un as *mut UnlockNotification as *mut c_void, + ) + }; + debug_assert!( + rc == ffi::SQLITE_LOCKED || rc == ffi::SQLITE_LOCKED_SHAREDCACHE || rc == ffi::SQLITE_OK + ); + if rc == ffi::SQLITE_OK { + un.wait(); + } + rc +} + +#[cfg(not(feature = "unlock_notify"))] +pub fn is_locked(_db: *mut ffi::sqlite3, _rc: c_int) -> bool { + unreachable!() +} + +#[cfg(not(feature = "unlock_notify"))] +pub fn wait_for_unlock_notify(_db: *mut ffi::sqlite3) -> c_int { + unreachable!() +} + +#[cfg(feature = "unlock_notify")] +#[cfg(test)] +mod test { + use std::sync::mpsc::sync_channel; + use std::thread; + use std::time; + use {Connection, OpenFlags, Result, Transaction, TransactionBehavior}; + + #[test] + fn test_unlock_notify() { + let url = "file::memory:?cache=shared"; + let flags = OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_URI; + let db1 = Connection::open_with_flags(url, flags).unwrap(); + db1.execute_batch("CREATE TABLE foo (x)").unwrap(); + let (rx, tx) = sync_channel(0); + let child = thread::spawn(move || { + let mut db2 = Connection::open_with_flags(url, flags).unwrap(); + let tx2 = Transaction::new(&mut db2, TransactionBehavior::Immediate).unwrap(); + tx2.execute_batch("INSERT INTO foo VALUES (42)").unwrap(); + rx.send(1).unwrap(); + let ten_millis = time::Duration::from_millis(10); + thread::sleep(ten_millis); + tx2.commit().unwrap(); + }); + assert_eq!(tx.recv().unwrap(), 1); + let the_answer: Result = db1.query_row("SELECT x FROM foo", &[], |r| r.get(0)); + assert_eq!(42i64, the_answer.unwrap()); + child.join().unwrap(); + } +}