3
0
mirror of https://github.com/isar/rusqlite.git synced 2025-04-01 11:32:57 +08:00

Remove Ref/UnwindSafe constraint on FFI callback

As suggested here:
https://github.com/rusqlite/rusqlite/pull/1052#issuecomment-988455248
This commit is contained in:
gwenn 2024-03-31 11:11:19 +02:00
parent 57a3a8f62e
commit 19b20e0fc3
4 changed files with 41 additions and 33 deletions

@ -1,7 +1,7 @@
//! Add, remove, or modify a collation //! Add, remove, or modify a collation
use std::cmp::Ordering; use std::cmp::Ordering;
use std::os::raw::{c_char, c_int, c_void}; use std::os::raw::{c_char, c_int, c_void};
use std::panic::{catch_unwind, UnwindSafe}; use std::panic::catch_unwind;
use std::ptr; use std::ptr;
use std::slice; use std::slice;
@ -18,7 +18,7 @@ impl Connection {
#[inline] #[inline]
pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()> pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()>
where where
C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static, C: Fn(&str, &str) -> Ordering + Send + 'static,
{ {
self.db self.db
.borrow_mut() .borrow_mut()
@ -44,7 +44,7 @@ impl Connection {
impl InnerConnection { impl InnerConnection {
fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()> fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()>
where where
C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static, C: Fn(&str, &str) -> Ordering + Send + 'static,
{ {
unsafe extern "C" fn call_boxed_closure<C>( unsafe extern "C" fn call_boxed_closure<C>(
arg1: *mut c_void, arg1: *mut c_void,

@ -444,7 +444,7 @@ impl Connection {
x_func: F, x_func: F,
) -> Result<()> ) -> Result<()>
where where
F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, F: FnMut(&Context<'_>) -> Result<T> + Send + 'static,
T: SqlFnOutput, T: SqlFnOutput,
{ {
self.db self.db
@ -526,7 +526,7 @@ impl InnerConnection {
x_func: F, x_func: F,
) -> Result<()> ) -> Result<()>
where where
F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, F: FnMut(&Context<'_>) -> Result<T> + Send + 'static,
T: SqlFnOutput, T: SqlFnOutput,
{ {
unsafe extern "C" fn call_boxed_closure<F, T>( unsafe extern "C" fn call_boxed_closure<F, T>(

@ -2,7 +2,7 @@
#![allow(non_camel_case_types)] #![allow(non_camel_case_types)]
use std::os::raw::{c_char, c_int, c_void}; use std::os::raw::{c_char, c_int, c_void};
use std::panic::{catch_unwind, RefUnwindSafe}; use std::panic::catch_unwind;
use std::ptr; use std::ptr;
use crate::ffi; use crate::ffi;
@ -388,7 +388,7 @@ impl Connection {
/// If the progress callback returns `true`, the operation is interrupted. /// If the progress callback returns `true`, the operation is interrupted.
pub fn progress_handler<F>(&self, num_ops: c_int, handler: Option<F>) pub fn progress_handler<F>(&self, num_ops: c_int, handler: Option<F>)
where where
F: FnMut() -> bool + Send + RefUnwindSafe + 'static, F: FnMut() -> bool + Send + 'static,
{ {
self.db.borrow_mut().progress_handler(num_ops, handler); self.db.borrow_mut().progress_handler(num_ops, handler);
} }
@ -398,7 +398,7 @@ impl Connection {
#[inline] #[inline]
pub fn authorizer<'c, F>(&self, hook: Option<F>) pub fn authorizer<'c, F>(&self, hook: Option<F>)
where where
F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static, F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,
{ {
self.db.borrow_mut().authorizer(hook); self.db.borrow_mut().authorizer(hook);
} }
@ -554,7 +554,7 @@ impl InnerConnection {
fn progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>) fn progress_handler<F>(&mut self, num_ops: c_int, handler: Option<F>)
where where
F: FnMut() -> bool + Send + RefUnwindSafe + 'static, F: FnMut() -> bool + Send + 'static,
{ {
unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int unsafe extern "C" fn call_boxed_closure<F>(p_arg: *mut c_void) -> c_int
where where
@ -586,7 +586,7 @@ impl InnerConnection {
fn authorizer<'c, F>(&'c mut self, authorizer: Option<F>) fn authorizer<'c, F>(&'c mut self, authorizer: Option<F>)
where where
F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + RefUnwindSafe + 'static, F: for<'r> FnMut(AuthContext<'r>) -> Authorization + Send + 'static,
{ {
unsafe extern "C" fn call_boxed_closure<'c, F>( unsafe extern "C" fn call_boxed_closure<'c, F>(
p_arg: *mut c_void, p_arg: *mut c_void,

@ -5,7 +5,7 @@ use std::ffi::CStr;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::marker::PhantomData; use std::marker::PhantomData;
use std::os::raw::{c_char, c_int, c_uchar, c_void}; use std::os::raw::{c_char, c_int, c_uchar, c_void};
use std::panic::{catch_unwind, RefUnwindSafe}; use std::panic::catch_unwind;
use std::ptr; use std::ptr;
use std::slice::{from_raw_parts, from_raw_parts_mut}; use std::slice::{from_raw_parts, from_raw_parts_mut};
@ -59,20 +59,22 @@ impl Session<'_> {
/// Set a table filter /// Set a table filter
pub fn table_filter<F>(&mut self, filter: Option<F>) pub fn table_filter<F>(&mut self, filter: Option<F>)
where where
F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, F: Fn(&str) -> bool + Send + 'static,
{ {
unsafe extern "C" fn call_boxed_closure<F>( unsafe extern "C" fn call_boxed_closure<F>(
p_arg: *mut c_void, p_arg: *mut c_void,
tbl_str: *const c_char, tbl_str: *const c_char,
) -> c_int ) -> c_int
where where
F: Fn(&str) -> bool + RefUnwindSafe, F: Fn(&str) -> bool,
{ {
let boxed_filter: *mut F = p_arg as *mut F;
let tbl_name = CStr::from_ptr(tbl_str).to_str(); let tbl_name = CStr::from_ptr(tbl_str).to_str();
c_int::from( c_int::from(
catch_unwind(|| (*boxed_filter)(tbl_name.expect("non-utf8 table name"))) catch_unwind(|| {
.unwrap_or_default(), let boxed_filter: *mut F = p_arg.cast::<F>();
(*boxed_filter)(tbl_name.expect("non-utf8 table name"))
})
.unwrap_or_default(),
) )
} }
@ -588,8 +590,8 @@ impl Connection {
/// Apply a changeset to a database /// Apply a changeset to a database
pub fn apply<F, C>(&self, cs: &Changeset, filter: Option<F>, conflict: C) -> Result<()> pub fn apply<F, C>(&self, cs: &Changeset, filter: Option<F>, conflict: C) -> Result<()>
where where
F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, F: Fn(&str) -> bool + Send + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + 'static,
{ {
let db = self.db.borrow_mut().db; let db = self.db.borrow_mut().db;
@ -626,8 +628,8 @@ impl Connection {
conflict: C, conflict: C,
) -> Result<()> ) -> Result<()>
where where
F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, F: Fn(&str) -> bool + Send + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + 'static,
{ {
let input_ref = &input; let input_ref = &input;
let db = self.db.borrow_mut().db; let db = self.db.borrow_mut().db;
@ -701,17 +703,21 @@ pub enum ConflictAction {
unsafe extern "C" fn call_filter<F, C>(p_ctx: *mut c_void, tbl_str: *const c_char) -> c_int unsafe extern "C" fn call_filter<F, C>(p_ctx: *mut c_void, tbl_str: *const c_char) -> c_int
where where
F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, F: Fn(&str) -> bool + Send + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + 'static,
{ {
let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C);
let tbl_name = CStr::from_ptr(tbl_str).to_str(); let tbl_name = CStr::from_ptr(tbl_str).to_str();
match *tuple { c_int::from(
(Some(ref filter), _) => c_int::from( catch_unwind(|| {
catch_unwind(|| filter(tbl_name.expect("illegal table name"))).unwrap_or_default(), let tuple: *mut (Option<F>, C) = p_ctx.cast::<(Option<F>, C)>();
), if let Some(ref filter) = (*tuple).0 {
_ => unimplemented!(), filter(tbl_name.expect("illegal table name"))
} } else {
true
}
})
.unwrap_or_default(),
)
} }
unsafe extern "C" fn call_conflict<F, C>( unsafe extern "C" fn call_conflict<F, C>(
@ -720,13 +726,15 @@ unsafe extern "C" fn call_conflict<F, C>(
p: *mut ffi::sqlite3_changeset_iter, p: *mut ffi::sqlite3_changeset_iter,
) -> c_int ) -> c_int
where where
F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static, F: Fn(&str) -> bool + Send + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static, C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + 'static,
{ {
let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C);
let conflict_type = ConflictType::from(e_conflict); let conflict_type = ConflictType::from(e_conflict);
let item = ChangesetItem { it: p }; let item = ChangesetItem { it: p };
if let Ok(action) = catch_unwind(|| (*tuple).1(conflict_type, item)) { if let Ok(action) = catch_unwind(|| {
let tuple: *mut (Option<F>, C) = p_ctx.cast::<(Option<F>, C)>();
(*tuple).1(conflict_type, item)
}) {
action as c_int action as c_int
} else { } else {
ffi::SQLITE_CHANGESET_ABORT ffi::SQLITE_CHANGESET_ABORT