From 57a3a8f62e09204a6665687498bfd0aa3e6a1a7f Mon Sep 17 00:00:00 2001 From: gwenn <45554+gwenn@users.noreply.github.com> Date: Sat, 30 Mar 2024 17:01:44 +0100 Subject: [PATCH 1/4] Add bindings to automatic extension loading API (#1487) * Add bindings to automatic extension loading API it doesn't seem possible to directly register an `AutoExtension`. --- Cargo.toml | 3 + .../bindgen-bindings/bindgen_3.14.0.rs | 8 +-- libsqlite3-sys/build.rs | 8 +-- .../sqlcipher/bindgen_bundled_version.rs | 8 +-- .../sqlite3/bindgen_bundled_version.rs | 8 +-- src/auto_extension.rs | 57 +++++++++++++++++++ src/lib.rs | 2 + tests/auto_ext.rs | 41 +++++++++++++ 8 files changed, 119 insertions(+), 16 deletions(-) create mode 100644 src/auto_extension.rs create mode 100644 tests/auto_ext.rs diff --git a/Cargo.toml b/Cargo.toml index 562a745..b781bf4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -140,6 +140,9 @@ bencher = "0.1" path = "libsqlite3-sys" version = "0.28.0" +[[test]] +name = "auto_ext" + [[test]] name = "config_log" harness = false diff --git a/libsqlite3-sys/bindgen-bindings/bindgen_3.14.0.rs b/libsqlite3-sys/bindgen-bindings/bindgen_3.14.0.rs index a2af6f0..bb6f871 100644 --- a/libsqlite3-sys/bindgen-bindings/bindgen_3.14.0.rs +++ b/libsqlite3-sys/bindgen-bindings/bindgen_3.14.0.rs @@ -5,8 +5,8 @@ extern "C" { xEntryPoint: ::std::option::Option< unsafe extern "C" fn( db: *mut sqlite3, - pzErrMsg: *mut *const ::std::os::raw::c_char, - pThunk: *const sqlite3_api_routines, + pzErrMsg: *mut *mut ::std::os::raw::c_char, + _: *const sqlite3_api_routines, ) -> ::std::os::raw::c_int, >, ) -> ::std::os::raw::c_int; @@ -16,8 +16,8 @@ extern "C" { xEntryPoint: ::std::option::Option< unsafe extern "C" fn( db: *mut sqlite3, - pzErrMsg: *mut *const ::std::os::raw::c_char, - pThunk: *const sqlite3_api_routines, + pzErrMsg: *mut *mut ::std::os::raw::c_char, + _: *const sqlite3_api_routines, ) -> ::std::os::raw::c_int, >, ) -> ::std::os::raw::c_int; diff --git a/libsqlite3-sys/build.rs b/libsqlite3-sys/build.rs index b60ea37..987b34f 100644 --- a/libsqlite3-sys/build.rs +++ b/libsqlite3-sys/build.rs @@ -554,8 +554,8 @@ mod bindings { xEntryPoint: ::std::option::Option< unsafe extern "C" fn( db: *mut sqlite3, - pzErrMsg: *mut *const ::std::os::raw::c_char, - pThunk: *const sqlite3_api_routines, + pzErrMsg: *mut *mut ::std::os::raw::c_char, + _: *const sqlite3_api_routines, ) -> ::std::os::raw::c_int, >, ) -> ::std::os::raw::c_int; @@ -568,8 +568,8 @@ mod bindings { xEntryPoint: ::std::option::Option< unsafe extern "C" fn( db: *mut sqlite3, - pzErrMsg: *mut *const ::std::os::raw::c_char, - pThunk: *const sqlite3_api_routines, + pzErrMsg: *mut *mut ::std::os::raw::c_char, + _: *const sqlite3_api_routines, ) -> ::std::os::raw::c_int, >, ) -> ::std::os::raw::c_int; diff --git a/libsqlite3-sys/sqlcipher/bindgen_bundled_version.rs b/libsqlite3-sys/sqlcipher/bindgen_bundled_version.rs index e27d49e..7bac3ff 100644 --- a/libsqlite3-sys/sqlcipher/bindgen_bundled_version.rs +++ b/libsqlite3-sys/sqlcipher/bindgen_bundled_version.rs @@ -5,8 +5,8 @@ extern "C" { xEntryPoint: ::std::option::Option< unsafe extern "C" fn( db: *mut sqlite3, - pzErrMsg: *mut *const ::std::os::raw::c_char, - pThunk: *const sqlite3_api_routines, + pzErrMsg: *mut *mut ::std::os::raw::c_char, + _: *const sqlite3_api_routines, ) -> ::std::os::raw::c_int, >, ) -> ::std::os::raw::c_int; @@ -16,8 +16,8 @@ extern "C" { xEntryPoint: ::std::option::Option< unsafe extern "C" fn( db: *mut sqlite3, - pzErrMsg: *mut *const ::std::os::raw::c_char, - pThunk: *const sqlite3_api_routines, + pzErrMsg: *mut *mut ::std::os::raw::c_char, + _: *const sqlite3_api_routines, ) -> ::std::os::raw::c_int, >, ) -> ::std::os::raw::c_int; diff --git a/libsqlite3-sys/sqlite3/bindgen_bundled_version.rs b/libsqlite3-sys/sqlite3/bindgen_bundled_version.rs index 7c3a762..996fbc5 100644 --- a/libsqlite3-sys/sqlite3/bindgen_bundled_version.rs +++ b/libsqlite3-sys/sqlite3/bindgen_bundled_version.rs @@ -5,8 +5,8 @@ extern "C" { xEntryPoint: ::std::option::Option< unsafe extern "C" fn( db: *mut sqlite3, - pzErrMsg: *mut *const ::std::os::raw::c_char, - pThunk: *const sqlite3_api_routines, + pzErrMsg: *mut *mut ::std::os::raw::c_char, + _: *const sqlite3_api_routines, ) -> ::std::os::raw::c_int, >, ) -> ::std::os::raw::c_int; @@ -16,8 +16,8 @@ extern "C" { xEntryPoint: ::std::option::Option< unsafe extern "C" fn( db: *mut sqlite3, - pzErrMsg: *mut *const ::std::os::raw::c_char, - pThunk: *const sqlite3_api_routines, + pzErrMsg: *mut *mut ::std::os::raw::c_char, + _: *const sqlite3_api_routines, ) -> ::std::os::raw::c_int, >, ) -> ::std::os::raw::c_int; diff --git a/src/auto_extension.rs b/src/auto_extension.rs new file mode 100644 index 0000000..acb7523 --- /dev/null +++ b/src/auto_extension.rs @@ -0,0 +1,57 @@ +//! Automatic axtension loading +use super::ffi; +use crate::error::{check, to_sqlite_error}; +use crate::{Connection, Result}; +use std::os::raw::{c_char, c_int}; + +/// Automatic extension initialization routine +pub type AutoExtension = fn(Connection) -> Result<()>; + +/// Raw automatic extension initialization routine +pub type RawAutoExtension = unsafe extern "C" fn( + db: *mut ffi::sqlite3, + pz_err_msg: *mut *mut c_char, + _: *const ffi::sqlite3_api_routines, +) -> c_int; + +/// Bridge bewteen `RawAutoExtension` and `AutoExtension` +/// +/// # Safety +/// * Opening a database from an auto-extension handler will lead to +/// an endless recursion of the auto-handler triggering itself +/// indirectly for each newly-opened database. +/// * Results are undefined if the given db is closed by an auto-extension. +/// * The list of auto-extensions should not be manipulated from an auto-extension. +pub unsafe fn init_auto_extension( + db: *mut ffi::sqlite3, + pz_err_msg: *mut *mut c_char, + ax: AutoExtension, +) -> c_int { + let c = Connection::from_handle(db); + match c.and_then(ax) { + Err(e) => to_sqlite_error(&e, pz_err_msg), + _ => ffi::SQLITE_OK, + } +} + +/// Register au auto-extension +/// +/// # Safety +/// * Opening a database from an auto-extension handler will lead to +/// an endless recursion of the auto-handler triggering itself +/// indirectly for each newly-opened database. +/// * Results are undefined if the given db is closed by an auto-extension. +/// * The list of auto-extensions should not be manipulated from an auto-extension. +pub unsafe fn register_auto_extension(ax: RawAutoExtension) -> Result<()> { + check(ffi::sqlite3_auto_extension(Some(ax))) +} + +/// Unregister the initialization routine +pub fn cancel_auto_extension(ax: RawAutoExtension) -> bool { + unsafe { ffi::sqlite3_cancel_auto_extension(Some(ax)) == 1 } +} + +/// Disable all automatic extensions previously registered +pub fn reset_auto_extension() { + unsafe { ffi::sqlite3_reset_auto_extension() } +} diff --git a/src/lib.rs b/src/lib.rs index 7572908..1df6e2b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,6 +93,8 @@ pub use rusqlite_macros::__bind; mod error; +#[cfg(not(feature = "loadable_extension"))] +pub mod auto_extension; #[cfg(feature = "backup")] #[cfg_attr(docsrs, doc(cfg(feature = "backup")))] pub mod backup; diff --git a/tests/auto_ext.rs b/tests/auto_ext.rs new file mode 100644 index 0000000..17ed682 --- /dev/null +++ b/tests/auto_ext.rs @@ -0,0 +1,41 @@ +#[cfg(all(feature = "bundled", not(feature = "loadable_extension")))] +#[test] +fn auto_ext() -> rusqlite::Result<()> { + use rusqlite::auto_extension::*; + use rusqlite::{ffi, Connection, Error, Result}; + use std::os::raw::{c_char, c_int}; + + fn test_ok(_: Connection) -> Result<()> { + Ok(()) + } + unsafe extern "C" fn sqlite_test_ok( + db: *mut ffi::sqlite3, + pz_err_msg: *mut *mut c_char, + _: *const ffi::sqlite3_api_routines, + ) -> c_int { + init_auto_extension(db, pz_err_msg, test_ok) + } + fn test_err(_: Connection) -> Result<()> { + Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_CORRUPT), + Some("AutoExtErr".to_owned()), + )) + } + unsafe extern "C" fn sqlite_test_err( + db: *mut ffi::sqlite3, + pz_err_msg: *mut *mut c_char, + _: *const ffi::sqlite3_api_routines, + ) -> c_int { + init_auto_extension(db, pz_err_msg, test_err) + } + + //assert!(!cancel_auto_extension(sqlite_test_ok)); + unsafe { register_auto_extension(sqlite_test_ok)? }; + Connection::open_in_memory()?; + assert!(cancel_auto_extension(sqlite_test_ok)); + assert!(!cancel_auto_extension(sqlite_test_ok)); + unsafe { register_auto_extension(sqlite_test_err)? }; + Connection::open_in_memory().unwrap_err(); + reset_auto_extension(); + Ok(()) +} From 19b20e0fc381ff31bd6e8171f0faef7923b92db4 Mon Sep 17 00:00:00 2001 From: gwenn <gtreguier@gmail.com> Date: Sun, 31 Mar 2024 11:11:19 +0200 Subject: [PATCH 2/4] Remove Ref/UnwindSafe constraint on FFI callback As suggested here: https://github.com/rusqlite/rusqlite/pull/1052#issuecomment-988455248 --- src/collation.rs | 6 +++--- src/functions.rs | 4 ++-- src/hooks.rs | 10 ++++----- src/session.rs | 54 +++++++++++++++++++++++++++--------------------- 4 files changed, 41 insertions(+), 33 deletions(-) diff --git a/src/collation.rs b/src/collation.rs index ade51e0..c467c62 100644 --- a/src/collation.rs +++ b/src/collation.rs @@ -1,7 +1,7 @@ //! Add, remove, or modify a collation use std::cmp::Ordering; use std::os::raw::{c_char, c_int, c_void}; -use std::panic::{catch_unwind, UnwindSafe}; +use std::panic::catch_unwind; use std::ptr; use std::slice; @@ -18,7 +18,7 @@ impl Connection { #[inline] pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()> where - C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static, + C: Fn(&str, &str) -> Ordering + Send + 'static, { self.db .borrow_mut() @@ -44,7 +44,7 @@ impl Connection { impl InnerConnection { fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()> where - C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static, + C: Fn(&str, &str) -> Ordering + Send + 'static, { unsafe extern "C" fn call_boxed_closure<C>( arg1: *mut c_void, diff --git a/src/functions.rs b/src/functions.rs index f4a508c..7a00152 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -444,7 +444,7 @@ impl Connection { x_func: F, ) -> Result<()> where - F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, + F: FnMut(&Context<'_>) -> Result<T> + Send + 'static, T: SqlFnOutput, { self.db @@ -526,7 +526,7 @@ impl InnerConnection { x_func: F, ) -> Result<()> where - F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, + F: FnMut(&Context<'_>) -> Result<T> + Send + 'static, T: SqlFnOutput, { unsafe extern "C" fn call_boxed_closure<F, T>( diff --git a/src/hooks.rs b/src/hooks.rs index 108334b..652d474 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -2,7 +2,7 @@ #![allow(non_camel_case_types)] use std::os::raw::{c_char, c_int, c_void}; -use std::panic::{catch_unwind, RefUnwindSafe}; +use std::panic::catch_unwind; use std::ptr; use crate::ffi; @@ -388,7 +388,7 @@ impl Connection { /// If the progress callback returns `true`, the operation is interrupted. pub fn progress_handler<F>(&self, num_ops: c_int, handler: Option<F>) where - F: FnMut() -> bool + Send + RefUnwindSafe + 'static, + F: FnMut() -> bool + Send + 'static, { self.db.borrow_mut().progress_handler(num_ops, handler); } @@ -398,7 +398,7 @@ impl Connection { #[inline] pub fn authorizer<'c, F>(&self, hook: Option<F>) where - F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static, + F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static, { self.db.borrow_mut().authorizer(hook); } @@ -554,7 +554,7 @@ impl InnerConnection { fn progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>) where - F: FnMut() -> bool + Send + RefUnwindSafe + 'static, + F: FnMut() -> bool + Send + 'static, { unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int where @@ -586,7 +586,7 @@ impl InnerConnection { fn authorizer<'c, F>(&'c mut self, authorizer: Option<F>) where - F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static, + F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static, { unsafe extern "C" fn call_boxed_closure<'c, F>( p_arg: *mut c_void, diff --git a/src/session.rs b/src/session.rs index c5c785d..a39de09 100644 --- a/src/session.rs +++ b/src/session.rs @@ -5,7 +5,7 @@ use std::ffi::CStr; use std::io::{Read, Write}; use std::marker::PhantomData; use std::os::raw::{c_char, c_int, c_uchar, c_void}; -use std::panic::{catch_unwind, RefUnwindSafe}; +use std::panic::catch_unwind; use std::ptr; use std::slice::{from_raw_parts, from_raw_parts_mut}; @@ -59,20 +59,22 @@ impl Session<'_> { /// Set a table filter pub fn table_filter<F>(&mut self, filter: Option<F>) where - F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, + F: Fn(&str) -> bool + Send + 'static, { unsafe extern "C" fn call_boxed_closure<F>( p_arg: *mut c_void, tbl_str: *const c_char, ) -> c_int where - F: Fn(&str) -> bool + RefUnwindSafe, + F: Fn(&str) -> bool, { - let boxed_filter: *mut F = p_arg as *mut F; let tbl_name = CStr::from_ptr(tbl_str).to_str(); c_int::from( - catch_unwind(|| (*boxed_filter)(tbl_name.expect("non-utf8 table name"))) - .unwrap_or_default(), + catch_unwind(|| { + let boxed_filter: *mut F = p_arg.cast::<F>(); + (*boxed_filter)(tbl_name.expect("non-utf8 table name")) + }) + .unwrap_or_default(), ) } @@ -588,8 +590,8 @@ impl Connection { /// Apply a changeset to a database pub fn apply<F, C>(&self, cs: &Changeset, filter: Option<F>, conflict: C) -> Result<()> where - F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, - C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, + F: Fn(&str) -> bool + Send + 'static, + C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + 'static, { let db = self.db.borrow_mut().db; @@ -626,8 +628,8 @@ impl Connection { conflict: C, ) -> Result<()> where - F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, - C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, + F: Fn(&str) -> bool + Send + 'static, + C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + 'static, { let input_ref = &input; let db = self.db.borrow_mut().db; @@ -701,17 +703,21 @@ pub enum ConflictAction { unsafe extern "C" fn call_filter<F, C>(p_ctx: *mut c_void, tbl_str: *const c_char) -> c_int where - F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, - C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, + F: Fn(&str) -> bool + Send + 'static, + C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + 'static, { - let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C); let tbl_name = CStr::from_ptr(tbl_str).to_str(); - match *tuple { - (Some(ref filter), _) => c_int::from( - catch_unwind(|| filter(tbl_name.expect("illegal table name"))).unwrap_or_default(), - ), - _ => unimplemented!(), - } + c_int::from( + catch_unwind(|| { + let tuple: *mut (Option<F>, C) = p_ctx.cast::<(Option<F>, C)>(); + if let Some(ref filter) = (*tuple).0 { + filter(tbl_name.expect("illegal table name")) + } else { + true + } + }) + .unwrap_or_default(), + ) } unsafe extern "C" fn call_conflict<F, C>( @@ -720,13 +726,15 @@ unsafe extern "C" fn call_conflict<F, C>( p: *mut ffi::sqlite3_changeset_iter, ) -> c_int where - F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, - C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, + F: Fn(&str) -> bool + Send + 'static, + C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + 'static, { - let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C); let conflict_type = ConflictType::from(e_conflict); let item = ChangesetItem { it: p }; - if let Ok(action) = catch_unwind(|| (*tuple).1(conflict_type, item)) { + if let Ok(action) = catch_unwind(|| { + let tuple: *mut (Option<F>, C) = p_ctx.cast::<(Option<F>, C)>(); + (*tuple).1(conflict_type, item) + }) { action as c_int } else { ffi::SQLITE_CHANGESET_ABORT From d8bcd4d28aa7ff82a226f34f7b7bc49d03eff569 Mon Sep 17 00:00:00 2001 From: gwenn <45554+gwenn@users.noreply.github.com> Date: Sun, 31 Mar 2024 11:41:05 +0200 Subject: [PATCH 3/4] Check callbacks lifetime (#1052) Check callbacks lifetime --- src/collation.rs | 22 ++++++++++++ src/functions.rs | 21 +++++++++++ src/hooks.rs | 91 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 134 insertions(+) diff --git a/src/collation.rs b/src/collation.rs index c467c62..cf4226b 100644 --- a/src/collation.rs +++ b/src/collation.rs @@ -42,6 +42,28 @@ impl Connection { } impl InnerConnection { + /// ```compile_fail + /// use rusqlite::{Connection, Result}; + /// fn main() -> Result<()> { + /// let db = Connection::open_in_memory()?; + /// { + /// let mut called = std::sync::atomic::AtomicBool::new(false); + /// db.create_collation("foo", |_, _| { + /// called.store(true, std::sync::atomic::Ordering::Relaxed); + /// std::cmp::Ordering::Equal + /// })?; + /// } + /// let value: String = db.query_row( + /// "WITH cte(bar) AS + /// (VALUES ('v1'),('v2'),('v3'),('v4'),('v5')) + /// SELECT DISTINCT bar COLLATE foo FROM cte;", + /// [], + /// |row| row.get(0), + /// )?; + /// assert_eq!(value, "v1"); + /// Ok(()) + /// } + /// ``` fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()> where C: Fn(&str, &str) -> Ordering + Send + 'static, diff --git a/src/functions.rs b/src/functions.rs index 7a00152..e6ab3f2 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -518,6 +518,27 @@ impl Connection { } impl InnerConnection { + /// ```compile_fail + /// use rusqlite::{functions::FunctionFlags, Connection, Result}; + /// fn main() -> Result<()> { + /// let db = Connection::open_in_memory()?; + /// { + /// let mut called = std::sync::atomic::AtomicBool::new(false); + /// db.create_scalar_function( + /// "test", + /// 0, + /// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + /// |_| { + /// called.store(true, std::sync::atomic::Ordering::Relaxed); + /// Ok(true) + /// }, + /// ); + /// } + /// let result: Result<bool> = db.query_row("SELECT test()", [], |r| r.get(0)); + /// assert!(result?); + /// Ok(()) + /// } + /// ``` fn create_scalar_function<F, T>( &mut self, fn_name: &str, diff --git a/src/hooks.rs b/src/hooks.rs index 652d474..2aac4c2 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -414,6 +414,27 @@ impl InnerConnection { self.authorizer(None::<fn(AuthContext<'_>) -> Authorization>); } + /// ```compile_fail + /// use rusqlite::{Connection, Result}; + /// fn main() -> Result<()> { + /// let db = Connection::open_in_memory()?; + /// { + /// let mut called = std::sync::atomic::AtomicBool::new(false); + /// db.commit_hook(Some(|| { + /// called.store(true, std::sync::atomic::Ordering::Relaxed); + /// true + /// })); + /// } + /// assert!(db + /// .execute_batch( + /// "BEGIN; + /// CREATE TABLE foo (t TEXT); + /// COMMIT;", + /// ) + /// .is_err()); + /// Ok(()) + /// } + /// ``` fn commit_hook<F>(&mut self, hook: Option<F>) where F: FnMut() -> bool + Send + 'static, @@ -459,6 +480,26 @@ impl InnerConnection { self.free_commit_hook = free_commit_hook; } + /// ```compile_fail + /// use rusqlite::{Connection, Result}; + /// fn main() -> Result<()> { + /// let db = Connection::open_in_memory()?; + /// { + /// let mut called = std::sync::atomic::AtomicBool::new(false); + /// db.rollback_hook(Some(|| { + /// called.store(true, std::sync::atomic::Ordering::Relaxed); + /// })); + /// } + /// assert!(db + /// .execute_batch( + /// "BEGIN; + /// CREATE TABLE foo (t TEXT); + /// ROLLBACK;", + /// ) + /// .is_err()); + /// Ok(()) + /// } + /// ``` fn rollback_hook<F>(&mut self, hook: Option<F>) where F: FnMut() + Send + 'static, @@ -500,6 +541,19 @@ impl InnerConnection { self.free_rollback_hook = free_rollback_hook; } + /// ```compile_fail + /// use rusqlite::{Connection, Result}; + /// fn main() -> Result<()> { + /// let db = Connection::open_in_memory()?; + /// { + /// let mut called = std::sync::atomic::AtomicBool::new(false); + /// db.update_hook(Some(|_, _: &str, _: &str, _| { + /// called.store(true, std::sync::atomic::Ordering::Relaxed); + /// })); + /// } + /// db.execute_batch("CREATE TABLE foo AS SELECT 1 AS bar;") + /// } + /// ``` fn update_hook<F>(&mut self, hook: Option<F>) where F: FnMut(Action, &str, &str, i64) + Send + 'static, @@ -552,6 +606,26 @@ impl InnerConnection { self.free_update_hook = free_update_hook; } + /// ```compile_fail + /// use rusqlite::{Connection, Result}; + /// fn main() -> Result<()> { + /// let db = Connection::open_in_memory()?; + /// { + /// let mut called = std::sync::atomic::AtomicBool::new(false); + /// db.progress_handler( + /// 1, + /// Some(|| { + /// called.store(true, std::sync::atomic::Ordering::Relaxed); + /// true + /// }), + /// ); + /// } + /// assert!(db + /// .execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") + /// .is_err()); + /// Ok(()) + /// } + /// ``` fn progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>) where F: FnMut() -> bool + Send + 'static, @@ -584,6 +658,23 @@ impl InnerConnection { }; } + /// ```compile_fail + /// use rusqlite::{Connection, Result}; + /// fn main() -> Result<()> { + /// let db = Connection::open_in_memory()?; + /// { + /// let mut called = std::sync::atomic::AtomicBool::new(false); + /// db.authorizer(Some(|_: rusqlite::hooks::AuthContext<'_>| { + /// called.store(true, std::sync::atomic::Ordering::Relaxed); + /// rusqlite::hooks::Authorization::Deny + /// })); + /// } + /// assert!(db + /// .execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;") + /// .is_err()); + /// Ok(()) + /// } + /// ``` fn authorizer<'c, F>(&'c mut self, authorizer: Option<F>) where F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static, From a0b410eb8641545ea88b515e45383043375cf85e Mon Sep 17 00:00:00 2001 From: gwenn <45554+gwenn@users.noreply.github.com> Date: Sun, 31 Mar 2024 13:22:00 +0200 Subject: [PATCH 4/4] Use catch_unwind in init_auto_extension (#1489) Use catch_unwind in init_auto_extension --- src/auto_extension.rs | 11 ++++++++--- src/error.rs | 5 ----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/auto_extension.rs b/src/auto_extension.rs index acb7523..b996006 100644 --- a/src/auto_extension.rs +++ b/src/auto_extension.rs @@ -1,8 +1,9 @@ //! Automatic axtension loading use super::ffi; use crate::error::{check, to_sqlite_error}; -use crate::{Connection, Result}; +use crate::{Connection, Error, Result}; use std::os::raw::{c_char, c_int}; +use std::panic::catch_unwind; /// Automatic extension initialization routine pub type AutoExtension = fn(Connection) -> Result<()>; @@ -27,8 +28,12 @@ pub unsafe fn init_auto_extension( pz_err_msg: *mut *mut c_char, ax: AutoExtension, ) -> c_int { - let c = Connection::from_handle(db); - match c.and_then(ax) { + let r = catch_unwind(|| { + let c = Connection::from_handle(db); + c.and_then(ax) + }) + .unwrap_or_else(|_| Err(Error::UnwindingPanic)); + match r { Err(e) => to_sqlite_error(&e, pz_err_msg), _ => ffi::SQLITE_OK, } diff --git a/src/error.rs b/src/error.rs index a9d557c..2946382 100644 --- a/src/error.rs +++ b/src/error.rs @@ -102,8 +102,6 @@ pub enum Error { ModuleError(String), /// An unwinding panic occurs in an UDF (user-defined function). - #[cfg(feature = "functions")] - #[cfg_attr(docsrs, doc(cfg(feature = "functions")))] UnwindingPanic, /// An error returned when @@ -185,7 +183,6 @@ impl PartialEq for Error { (Error::InvalidQuery, Error::InvalidQuery) => true, #[cfg(feature = "vtab")] (Error::ModuleError(s1), Error::ModuleError(s2)) => s1 == s2, - #[cfg(feature = "functions")] (Error::UnwindingPanic, Error::UnwindingPanic) => true, #[cfg(feature = "functions")] (Error::GetAuxWrongType, Error::GetAuxWrongType) => true, @@ -318,7 +315,6 @@ impl fmt::Display for Error { Error::InvalidQuery => write!(f, "Query is not read-only"), #[cfg(feature = "vtab")] Error::ModuleError(ref desc) => write!(f, "{desc}"), - #[cfg(feature = "functions")] Error::UnwindingPanic => write!(f, "unwinding panic"), #[cfg(feature = "functions")] Error::GetAuxWrongType => write!(f, "get_aux called with wrong type"), @@ -375,7 +371,6 @@ impl error::Error for Error { #[cfg(feature = "vtab")] Error::ModuleError(_) => None, - #[cfg(feature = "functions")] Error::UnwindingPanic => None, #[cfg(feature = "functions")]