Merge remote-tracking branch 'upstream/master' into preupdate_hook

This commit is contained in:
Austin Schey 2024-03-31 07:57:20 -05:00
commit 0d56555875
13 changed files with 299 additions and 54 deletions

View File

@ -141,6 +141,9 @@ bencher = "0.1"
path = "libsqlite3-sys"
version = "0.28.0"
[[test]]
name = "auto_ext"
[[test]]
name = "config_log"
harness = false

View File

@ -5,8 +5,8 @@ extern "C" {
xEntryPoint: ::std::option::Option<
unsafe extern "C" fn(
db: *mut sqlite3,
pzErrMsg: *mut *const ::std::os::raw::c_char,
pThunk: *const sqlite3_api_routines,
pzErrMsg: *mut *mut ::std::os::raw::c_char,
_: *const sqlite3_api_routines,
) -> ::std::os::raw::c_int,
>,
) -> ::std::os::raw::c_int;
@ -16,8 +16,8 @@ extern "C" {
xEntryPoint: ::std::option::Option<
unsafe extern "C" fn(
db: *mut sqlite3,
pzErrMsg: *mut *const ::std::os::raw::c_char,
pThunk: *const sqlite3_api_routines,
pzErrMsg: *mut *mut ::std::os::raw::c_char,
_: *const sqlite3_api_routines,
) -> ::std::os::raw::c_int,
>,
) -> ::std::os::raw::c_int;

View File

@ -554,8 +554,8 @@ mod bindings {
xEntryPoint: ::std::option::Option<
unsafe extern "C" fn(
db: *mut sqlite3,
pzErrMsg: *mut *const ::std::os::raw::c_char,
pThunk: *const sqlite3_api_routines,
pzErrMsg: *mut *mut ::std::os::raw::c_char,
_: *const sqlite3_api_routines,
) -> ::std::os::raw::c_int,
>,
) -> ::std::os::raw::c_int;
@ -568,8 +568,8 @@ mod bindings {
xEntryPoint: ::std::option::Option<
unsafe extern "C" fn(
db: *mut sqlite3,
pzErrMsg: *mut *const ::std::os::raw::c_char,
pThunk: *const sqlite3_api_routines,
pzErrMsg: *mut *mut ::std::os::raw::c_char,
_: *const sqlite3_api_routines,
) -> ::std::os::raw::c_int,
>,
) -> ::std::os::raw::c_int;

View File

@ -5,8 +5,8 @@ extern "C" {
xEntryPoint: ::std::option::Option<
unsafe extern "C" fn(
db: *mut sqlite3,
pzErrMsg: *mut *const ::std::os::raw::c_char,
pThunk: *const sqlite3_api_routines,
pzErrMsg: *mut *mut ::std::os::raw::c_char,
_: *const sqlite3_api_routines,
) -> ::std::os::raw::c_int,
>,
) -> ::std::os::raw::c_int;
@ -16,8 +16,8 @@ extern "C" {
xEntryPoint: ::std::option::Option<
unsafe extern "C" fn(
db: *mut sqlite3,
pzErrMsg: *mut *const ::std::os::raw::c_char,
pThunk: *const sqlite3_api_routines,
pzErrMsg: *mut *mut ::std::os::raw::c_char,
_: *const sqlite3_api_routines,
) -> ::std::os::raw::c_int,
>,
) -> ::std::os::raw::c_int;

View File

@ -5,8 +5,8 @@ extern "C" {
xEntryPoint: ::std::option::Option<
unsafe extern "C" fn(
db: *mut sqlite3,
pzErrMsg: *mut *const ::std::os::raw::c_char,
pThunk: *const sqlite3_api_routines,
pzErrMsg: *mut *mut ::std::os::raw::c_char,
_: *const sqlite3_api_routines,
) -> ::std::os::raw::c_int,
>,
) -> ::std::os::raw::c_int;
@ -16,8 +16,8 @@ extern "C" {
xEntryPoint: ::std::option::Option<
unsafe extern "C" fn(
db: *mut sqlite3,
pzErrMsg: *mut *const ::std::os::raw::c_char,
pThunk: *const sqlite3_api_routines,
pzErrMsg: *mut *mut ::std::os::raw::c_char,
_: *const sqlite3_api_routines,
) -> ::std::os::raw::c_int,
>,
) -> ::std::os::raw::c_int;

62
src/auto_extension.rs Normal file
View File

@ -0,0 +1,62 @@
//! Automatic axtension loading
use super::ffi;
use crate::error::{check, to_sqlite_error};
use crate::{Connection, Error, Result};
use std::os::raw::{c_char, c_int};
use std::panic::catch_unwind;
/// Automatic extension initialization routine
pub type AutoExtension = fn(Connection) -> Result<()>;
/// Raw automatic extension initialization routine
pub type RawAutoExtension = unsafe extern "C" fn(
db: *mut ffi::sqlite3,
pz_err_msg: *mut *mut c_char,
_: *const ffi::sqlite3_api_routines,
) -> c_int;
/// Bridge bewteen `RawAutoExtension` and `AutoExtension`
///
/// # Safety
/// * Opening a database from an auto-extension handler will lead to
/// an endless recursion of the auto-handler triggering itself
/// indirectly for each newly-opened database.
/// * Results are undefined if the given db is closed by an auto-extension.
/// * The list of auto-extensions should not be manipulated from an auto-extension.
pub unsafe fn init_auto_extension(
db: *mut ffi::sqlite3,
pz_err_msg: *mut *mut c_char,
ax: AutoExtension,
) -> c_int {
let r = catch_unwind(|| {
let c = Connection::from_handle(db);
c.and_then(ax)
})
.unwrap_or_else(|_| Err(Error::UnwindingPanic));
match r {
Err(e) => to_sqlite_error(&e, pz_err_msg),
_ => ffi::SQLITE_OK,
}
}
/// Register au auto-extension
///
/// # Safety
/// * Opening a database from an auto-extension handler will lead to
/// an endless recursion of the auto-handler triggering itself
/// indirectly for each newly-opened database.
/// * Results are undefined if the given db is closed by an auto-extension.
/// * The list of auto-extensions should not be manipulated from an auto-extension.
pub unsafe fn register_auto_extension(ax: RawAutoExtension) -> Result<()> {
check(ffi::sqlite3_auto_extension(Some(ax)))
}
/// Unregister the initialization routine
pub fn cancel_auto_extension(ax: RawAutoExtension) -> bool {
unsafe { ffi::sqlite3_cancel_auto_extension(Some(ax)) == 1 }
}
/// Disable all automatic extensions previously registered
pub fn reset_auto_extension() {
unsafe { ffi::sqlite3_reset_auto_extension() }
}

View File

@ -1,7 +1,7 @@
//! Add, remove, or modify a collation
use std::cmp::Ordering;
use std::os::raw::{c_char, c_int, c_void};
use std::panic::{catch_unwind, UnwindSafe};
use std::panic::catch_unwind;
use std::ptr;
use std::slice;
@ -18,7 +18,7 @@ impl Connection {
#[inline]
pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()>
where
C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
C: Fn(&str, &str) -> Ordering + Send + 'static,
{
self.db
.borrow_mut()
@ -42,9 +42,31 @@ impl Connection {
}
impl InnerConnection {
/// ```compile_fail
/// use rusqlite::{Connection, Result};
/// fn main() -> Result<()> {
/// let db = Connection::open_in_memory()?;
/// {
/// let mut called = std::sync::atomic::AtomicBool::new(false);
/// db.create_collation("foo", |_, _| {
/// called.store(true, std::sync::atomic::Ordering::Relaxed);
/// std::cmp::Ordering::Equal
/// })?;
/// }
/// let value: String = db.query_row(
/// "WITH cte(bar) AS
/// (VALUES ('v1'),('v2'),('v3'),('v4'),('v5'))
/// SELECT DISTINCT bar COLLATE foo FROM cte;",
/// [],
/// |row| row.get(0),
/// )?;
/// assert_eq!(value, "v1");
/// Ok(())
/// }
/// ```
fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()>
where
C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
C: Fn(&str, &str) -> Ordering + Send + 'static,
{
unsafe extern "C" fn call_boxed_closure<C>(
arg1: *mut c_void,

View File

@ -102,8 +102,6 @@ pub enum Error {
ModuleError(String),
/// An unwinding panic occurs in an UDF (user-defined function).
#[cfg(feature = "functions")]
#[cfg_attr(docsrs, doc(cfg(feature = "functions")))]
UnwindingPanic,
/// An error returned when
@ -185,7 +183,6 @@ impl PartialEq for Error {
(Error::InvalidQuery, Error::InvalidQuery) => true,
#[cfg(feature = "vtab")]
(Error::ModuleError(s1), Error::ModuleError(s2)) => s1 == s2,
#[cfg(feature = "functions")]
(Error::UnwindingPanic, Error::UnwindingPanic) => true,
#[cfg(feature = "functions")]
(Error::GetAuxWrongType, Error::GetAuxWrongType) => true,
@ -318,7 +315,6 @@ impl fmt::Display for Error {
Error::InvalidQuery => write!(f, "Query is not read-only"),
#[cfg(feature = "vtab")]
Error::ModuleError(ref desc) => write!(f, "{desc}"),
#[cfg(feature = "functions")]
Error::UnwindingPanic => write!(f, "unwinding panic"),
#[cfg(feature = "functions")]
Error::GetAuxWrongType => write!(f, "get_aux called with wrong type"),
@ -375,7 +371,6 @@ impl error::Error for Error {
#[cfg(feature = "vtab")]
Error::ModuleError(_) => None,
#[cfg(feature = "functions")]
Error::UnwindingPanic => None,
#[cfg(feature = "functions")]

View File

@ -444,7 +444,7 @@ impl Connection {
x_func: F,
) -> Result<()>
where
F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
F: FnMut(&Context<'_>) -> Result<T> + Send + 'static,
T: SqlFnOutput,
{
self.db
@ -518,6 +518,27 @@ impl Connection {
}
impl InnerConnection {
/// ```compile_fail
/// use rusqlite::{functions::FunctionFlags, Connection, Result};
/// fn main() -> Result<()> {
/// let db = Connection::open_in_memory()?;
/// {
/// let mut called = std::sync::atomic::AtomicBool::new(false);
/// db.create_scalar_function(
/// "test",
/// 0,
/// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
/// |_| {
/// called.store(true, std::sync::atomic::Ordering::Relaxed);
/// Ok(true)
/// },
/// );
/// }
/// let result: Result<bool> = db.query_row("SELECT test()", [], |r| r.get(0));
/// assert!(result?);
/// Ok(())
/// }
/// ```
fn create_scalar_function<F, T>(
&mut self,
fn_name: &str,
@ -526,7 +547,7 @@ impl InnerConnection {
x_func: F,
) -> Result<()>
where
F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
F: FnMut(&Context<'_>) -> Result<T> + Send + 'static,
T: SqlFnOutput,
{
unsafe extern "C" fn call_boxed_closure<F, T>(

View File

@ -2,7 +2,7 @@
#![allow(non_camel_case_types)]
use std::os::raw::{c_char, c_int, c_void};
use std::panic::{catch_unwind, RefUnwindSafe};
use std::panic::catch_unwind;
use std::ptr;
use crate::ffi;
@ -394,7 +394,7 @@ impl Connection {
/// If the progress callback returns `true`, the operation is interrupted.
pub fn progress_handler<F>(&self, num_ops: c_int, handler: Option<F>)
where
F: FnMut() -> bool + Send + RefUnwindSafe + 'static,
F: FnMut() -> bool + Send + 'static,
{
self.db.borrow_mut().progress_handler(num_ops, handler);
}
@ -404,7 +404,7 @@ impl Connection {
#[inline]
pub fn authorizer<'c, F>(&self, hook: Option<F>)
where
F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static,
F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,
{
self.db.borrow_mut().authorizer(hook);
}
@ -420,6 +420,27 @@ impl InnerConnection {
self.authorizer(None::<fn(AuthContext<'_>) -> Authorization>);
}
/// ```compile_fail
/// use rusqlite::{Connection, Result};
/// fn main() -> Result<()> {
/// let db = Connection::open_in_memory()?;
/// {
/// let mut called = std::sync::atomic::AtomicBool::new(false);
/// db.commit_hook(Some(|| {
/// called.store(true, std::sync::atomic::Ordering::Relaxed);
/// true
/// }));
/// }
/// assert!(db
/// .execute_batch(
/// "BEGIN;
/// CREATE TABLE foo (t TEXT);
/// COMMIT;",
/// )
/// .is_err());
/// Ok(())
/// }
/// ```
fn commit_hook<F>(&mut self, hook: Option<F>)
where
F: FnMut() -> bool + Send + 'static,
@ -465,6 +486,26 @@ impl InnerConnection {
self.free_commit_hook = free_commit_hook;
}
/// ```compile_fail
/// use rusqlite::{Connection, Result};
/// fn main() -> Result<()> {
/// let db = Connection::open_in_memory()?;
/// {
/// let mut called = std::sync::atomic::AtomicBool::new(false);
/// db.rollback_hook(Some(|| {
/// called.store(true, std::sync::atomic::Ordering::Relaxed);
/// }));
/// }
/// assert!(db
/// .execute_batch(
/// "BEGIN;
/// CREATE TABLE foo (t TEXT);
/// ROLLBACK;",
/// )
/// .is_err());
/// Ok(())
/// }
/// ```
fn rollback_hook<F>(&mut self, hook: Option<F>)
where
F: FnMut() + Send + 'static,
@ -506,6 +547,19 @@ impl InnerConnection {
self.free_rollback_hook = free_rollback_hook;
}
/// ```compile_fail
/// use rusqlite::{Connection, Result};
/// fn main() -> Result<()> {
/// let db = Connection::open_in_memory()?;
/// {
/// let mut called = std::sync::atomic::AtomicBool::new(false);
/// db.update_hook(Some(|_, _: &str, _: &str, _| {
/// called.store(true, std::sync::atomic::Ordering::Relaxed);
/// }));
/// }
/// db.execute_batch("CREATE TABLE foo AS SELECT 1 AS bar;")
/// }
/// ```
fn update_hook<F>(&mut self, hook: Option<F>)
where
F: FnMut(Action, &str, &str, i64) + Send + 'static,
@ -558,9 +612,29 @@ impl InnerConnection {
self.free_update_hook = free_update_hook;
}
/// ```compile_fail
/// use rusqlite::{Connection, Result};
/// fn main() -> Result<()> {
/// let db = Connection::open_in_memory()?;
/// {
/// let mut called = std::sync::atomic::AtomicBool::new(false);
/// db.progress_handler(
/// 1,
/// Some(|| {
/// called.store(true, std::sync::atomic::Ordering::Relaxed);
/// true
/// }),
/// );
/// }
/// assert!(db
/// .execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
/// .is_err());
/// Ok(())
/// }
/// ```
fn progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>)
where
F: FnMut() -> bool + Send + RefUnwindSafe + 'static,
F: FnMut() -> bool + Send + 'static,
{
unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
where
@ -590,9 +664,26 @@ impl InnerConnection {
};
}
/// ```compile_fail
/// use rusqlite::{Connection, Result};
/// fn main() -> Result<()> {
/// let db = Connection::open_in_memory()?;
/// {
/// let mut called = std::sync::atomic::AtomicBool::new(false);
/// db.authorizer(Some(|_: rusqlite::hooks::AuthContext<'_>| {
/// called.store(true, std::sync::atomic::Ordering::Relaxed);
/// rusqlite::hooks::Authorization::Deny
/// }));
/// }
/// assert!(db
/// .execute_batch("BEGIN; CREATE TABLE foo (t TEXT); COMMIT;")
/// .is_err());
/// Ok(())
/// }
/// ```
fn authorizer<'c, F>(&'c mut self, authorizer: Option<F>)
where
F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static,
F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,
{
unsafe extern "C" fn call_boxed_closure<'c, F>(
p_arg: *mut c_void,

View File

@ -93,6 +93,8 @@ pub use rusqlite_macros::__bind;
mod error;
#[cfg(not(feature = "loadable_extension"))]
pub mod auto_extension;
#[cfg(feature = "backup")]
#[cfg_attr(docsrs, doc(cfg(feature = "backup")))]
pub mod backup;

View File

@ -5,7 +5,7 @@ use std::ffi::CStr;
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::os::raw::{c_char, c_int, c_uchar, c_void};
use std::panic::{catch_unwind, RefUnwindSafe};
use std::panic::catch_unwind;
use std::ptr;
use std::slice::{from_raw_parts, from_raw_parts_mut};
@ -59,20 +59,22 @@ impl Session<'_> {
/// Set a table filter
pub fn table_filter<F>(&mut self, filter: Option<F>)
where
F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static,
F: Fn(&str) -> bool + Send + 'static,
{
unsafe extern "C" fn call_boxed_closure<F>(
p_arg: *mut c_void,
tbl_str: *const c_char,
) -> c_int
where
F: Fn(&str) -> bool + RefUnwindSafe,
F: Fn(&str) -> bool,
{
let boxed_filter: *mut F = p_arg as *mut F;
let tbl_name = CStr::from_ptr(tbl_str).to_str();
c_int::from(
catch_unwind(|| (*boxed_filter)(tbl_name.expect("non-utf8 table name")))
.unwrap_or_default(),
catch_unwind(|| {
let boxed_filter: *mut F = p_arg.cast::<F>();
(*boxed_filter)(tbl_name.expect("non-utf8 table name"))
})
.unwrap_or_default(),
)
}
@ -588,8 +590,8 @@ impl Connection {
/// Apply a changeset to a database
pub fn apply<F, C>(&self, cs: &Changeset, filter: Option<F>, conflict: C) -> Result<()>
where
F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static,
F: Fn(&str) -> bool + Send + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + 'static,
{
let db = self.db.borrow_mut().db;
@ -626,8 +628,8 @@ impl Connection {
conflict: C,
) -> Result<()>
where
F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static,
F: Fn(&str) -> bool + Send + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + 'static,
{
let input_ref = &input;
let db = self.db.borrow_mut().db;
@ -701,17 +703,21 @@ pub enum ConflictAction {
unsafe extern "C" fn call_filter<F, C>(p_ctx: *mut c_void, tbl_str: *const c_char) -> c_int
where
F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static,
F: Fn(&str) -> bool + Send + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + 'static,
{
let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C);
let tbl_name = CStr::from_ptr(tbl_str).to_str();
match *tuple {
(Some(ref filter), _) => c_int::from(
catch_unwind(|| filter(tbl_name.expect("illegal table name"))).unwrap_or_default(),
),
_ => unimplemented!(),
}
c_int::from(
catch_unwind(|| {
let tuple: *mut (Option<F>, C) = p_ctx.cast::<(Option<F>, C)>();
if let Some(ref filter) = (*tuple).0 {
filter(tbl_name.expect("illegal table name"))
} else {
true
}
})
.unwrap_or_default(),
)
}
unsafe extern "C" fn call_conflict<F, C>(
@ -720,13 +726,15 @@ unsafe extern "C" fn call_conflict<F, C>(
p: *mut ffi::sqlite3_changeset_iter,
) -> c_int
where
F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static,
F: Fn(&str) -> bool + Send + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + 'static,
{
let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C);
let conflict_type = ConflictType::from(e_conflict);
let item = ChangesetItem { it: p };
if let Ok(action) = catch_unwind(|| (*tuple).1(conflict_type, item)) {
if let Ok(action) = catch_unwind(|| {
let tuple: *mut (Option<F>, C) = p_ctx.cast::<(Option<F>, C)>();
(*tuple).1(conflict_type, item)
}) {
action as c_int
} else {
ffi::SQLITE_CHANGESET_ABORT

41
tests/auto_ext.rs Normal file
View File

@ -0,0 +1,41 @@
#[cfg(all(feature = "bundled", not(feature = "loadable_extension")))]
#[test]
fn auto_ext() -> rusqlite::Result<()> {
use rusqlite::auto_extension::*;
use rusqlite::{ffi, Connection, Error, Result};
use std::os::raw::{c_char, c_int};
fn test_ok(_: Connection) -> Result<()> {
Ok(())
}
unsafe extern "C" fn sqlite_test_ok(
db: *mut ffi::sqlite3,
pz_err_msg: *mut *mut c_char,
_: *const ffi::sqlite3_api_routines,
) -> c_int {
init_auto_extension(db, pz_err_msg, test_ok)
}
fn test_err(_: Connection) -> Result<()> {
Err(Error::SqliteFailure(
ffi::Error::new(ffi::SQLITE_CORRUPT),
Some("AutoExtErr".to_owned()),
))
}
unsafe extern "C" fn sqlite_test_err(
db: *mut ffi::sqlite3,
pz_err_msg: *mut *mut c_char,
_: *const ffi::sqlite3_api_routines,
) -> c_int {
init_auto_extension(db, pz_err_msg, test_err)
}
//assert!(!cancel_auto_extension(sqlite_test_ok));
unsafe { register_auto_extension(sqlite_test_ok)? };
Connection::open_in_memory()?;
assert!(cancel_auto_extension(sqlite_test_ok));
assert!(!cancel_auto_extension(sqlite_test_ok));
unsafe { register_auto_extension(sqlite_test_err)? };
Connection::open_in_memory().unwrap_err();
reset_auto_extension();
Ok(())
}