diff --git a/src/busy.rs b/src/busy.rs index 2dca96d..e0ee835 100644 --- a/src/busy.rs +++ b/src/busy.rs @@ -1,6 +1,7 @@ ///! Busy handler (when the database is locked) use std::mem; use std::os::raw::{c_int, c_void}; +use std::panic::catch_unwind; use std::ptr; use std::time::Duration; @@ -48,7 +49,7 @@ impl Connection { pub fn busy_handler(&self, callback: Option bool>) -> Result<()> { unsafe extern "C" fn busy_handler_callback(p_arg: *mut c_void, count: c_int) -> c_int { let handler_fn: fn(i32) -> bool = mem::transmute(p_arg); - if handler_fn(count) { + if let Ok(true) = catch_unwind(|| handler_fn(count)) { 1 } else { 0 diff --git a/src/error.rs b/src/error.rs index f451b4e..4bcfdf5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -91,6 +91,9 @@ pub enum Error { #[cfg(feature = "vtab")] #[allow(dead_code)] ModuleError(String), + + #[cfg(feature = "functions")] + UnwindingPanic, } impl From for Error { @@ -151,6 +154,8 @@ impl fmt::Display for Error { Error::InvalidQuery => write!(f, "Query is not read-only"), #[cfg(feature = "vtab")] Error::ModuleError(ref desc) => write!(f, "{}", desc), + #[cfg(feature = "functions")] + Error::UnwindingPanic => write!(f, "unwinding panic"), } } } @@ -188,6 +193,8 @@ impl error::Error for Error { Error::InvalidQuery => "query is not read-only", #[cfg(feature = "vtab")] Error::ModuleError(ref desc) => desc, + #[cfg(feature = "functions")] + Error::UnwindingPanic => "unwinding panic", } } @@ -222,6 +229,9 @@ impl error::Error for Error { #[cfg(feature = "vtab")] Error::ModuleError(_) => None, + + #[cfg(feature = "functions")] + Error::UnwindingPanic => None, } } } diff --git a/src/functions.rs b/src/functions.rs index f119b74..c4f398d 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -55,6 +55,7 @@ //! ``` use std::error::Error as StdError; use std::os::raw::{c_int, c_void}; +use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe}; use std::ptr; use std::slice; @@ -193,6 +194,7 @@ impl<'a> Context<'a> { /// result. Implementations should be stateless. pub trait Aggregate where + A: RefUnwindSafe + UnwindSafe, T: ToSql, { /// Initializes the aggregation context. Will be called prior to the first @@ -250,7 +252,7 @@ impl Connection { x_func: F, ) -> Result<()> where - F: FnMut(&Context<'_>) -> Result + Send + 'static, + F: FnMut(&Context<'_>) -> Result + Send + UnwindSafe + 'static, T: ToSql, { self.db @@ -271,6 +273,7 @@ impl Connection { aggr: D, ) -> Result<()> where + A: RefUnwindSafe + UnwindSafe, D: Aggregate, T: ToSql, { @@ -301,7 +304,7 @@ impl InnerConnection { x_func: F, ) -> Result<()> where - F: FnMut(&Context<'_>) -> Result + Send + 'static, + F: FnMut(&Context<'_>) -> Result + Send + UnwindSafe + 'static, T: ToSql, { unsafe extern "C" fn call_boxed_closure( @@ -312,20 +315,28 @@ impl InnerConnection { F: FnMut(&Context<'_>) -> Result, T: ToSql, { - let ctx = Context { - ctx, - args: slice::from_raw_parts(argv, argc as usize), + let r = catch_unwind(|| { + let boxed_f: *mut F = ffi::sqlite3_user_data(ctx) as *mut F; + assert!(!boxed_f.is_null(), "Internal error - null function pointer"); + let ctx = Context { + ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + (*boxed_f)(&ctx) + }); + let t = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, }; - let boxed_f: *mut F = ffi::sqlite3_user_data(ctx.ctx) as *mut F; - assert!(!boxed_f.is_null(), "Internal error - null function pointer"); - - let t = (*boxed_f)(&ctx); let t = t.as_ref().map(|t| ToSql::to_sql(t)); match t { - Ok(Ok(ref value)) => set_result(ctx.ctx, value), - Ok(Err(err)) => report_error(ctx.ctx, &err), - Err(err) => report_error(ctx.ctx, err), + Ok(Ok(ref value)) => set_result(ctx, value), + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), } } @@ -359,6 +370,7 @@ impl InnerConnection { aggr: D, ) -> Result<()> where + A: RefUnwindSafe + UnwindSafe, D: Aggregate, T: ToSql, { @@ -378,15 +390,10 @@ impl InnerConnection { argc: c_int, argv: *mut *mut sqlite3_value, ) where + A: RefUnwindSafe + UnwindSafe, D: Aggregate, T: ToSql, { - let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; - assert!( - !boxed_aggr.is_null(), - "Internal error - null aggregate pointer" - ); - let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) { Some(pac) => pac, None => { @@ -395,32 +402,40 @@ impl InnerConnection { } }; - if (*pac as *mut A).is_null() { - *pac = Box::into_raw(Box::new((*boxed_aggr).init())); - } - - let mut ctx = Context { - ctx, - args: slice::from_raw_parts(argv, argc as usize), + let r = catch_unwind(|| { + let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + if (*pac as *mut A).is_null() { + *pac = Box::into_raw(Box::new((*boxed_aggr).init())); + } + let mut ctx = Context { + ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + (*boxed_aggr).step(&mut ctx, &mut **pac) + }); + let r = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, }; - - match (*boxed_aggr).step(&mut ctx, &mut **pac) { + match r { Ok(_) => {} - Err(err) => report_error(ctx.ctx, &err), + Err(err) => report_error(ctx, &err), }; } unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) where + A: RefUnwindSafe + UnwindSafe, D: Aggregate, T: ToSql, { - let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; - assert!( - !boxed_aggr.is_null(), - "Internal error - null aggregate pointer" - ); - // Within the xFinal callback, it is customary to set N=0 in calls to // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. let a: Option = match aggregate_context(ctx, 0) { @@ -435,7 +450,21 @@ impl InnerConnection { None => None, }; - let t = (*boxed_aggr).finalize(a); + let r = catch_unwind(|| { + let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + (*boxed_aggr).finalize(a) + }); + let t = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; let t = t.as_ref().map(|t| ToSql::to_sql(t)); match t { Ok(Ok(ref value)) => set_result(ctx, value), diff --git a/src/hooks.rs b/src/hooks.rs index 947e7f4..7c26877 100644 --- a/src/hooks.rs +++ b/src/hooks.rs @@ -2,6 +2,7 @@ #![allow(non_camel_case_types)] use std::os::raw::{c_char, c_int, c_void}; +use std::panic::catch_unwind; use std::ptr; use crate::ffi; @@ -146,8 +147,11 @@ impl InnerConnection { where F: FnMut() -> bool, { - let boxed_hook: *mut F = p_arg as *mut F; - if (*boxed_hook)() { + let r = catch_unwind(|| { + let boxed_hook: *mut F = p_arg as *mut F; + (*boxed_hook)() + }); + if let Ok(true) = r { 1 } else { 0 @@ -192,8 +196,10 @@ impl InnerConnection { where F: FnMut(), { - let boxed_hook: *mut F = p_arg as *mut F; - (*boxed_hook)(); + let _ = catch_unwind(|| { + let boxed_hook: *mut F = p_arg as *mut F; + (*boxed_hook)(); + }); } let free_rollback_hook = if hook.is_some() { @@ -239,8 +245,6 @@ impl InnerConnection { use std::ffi::CStr; use std::str; - let boxed_hook: *mut F = p_arg as *mut F; - let action = Action::from(action_code); let db_name = { let c_slice = CStr::from_ptr(db_str).to_bytes(); @@ -251,7 +255,10 @@ impl InnerConnection { str::from_utf8_unchecked(c_slice) }; - (*boxed_hook)(action, db_name, tbl_name, row_id); + let _ = catch_unwind(|| { + let boxed_hook: *mut F = p_arg as *mut F; + (*boxed_hook)(action, db_name, tbl_name, row_id); + }); } let free_update_hook = if hook.is_some() { diff --git a/src/trace.rs b/src/trace.rs index 810b8d0..ddf537f 100644 --- a/src/trace.rs +++ b/src/trace.rs @@ -3,6 +3,7 @@ use std::ffi::{CStr, CString}; use std::mem; use std::os::raw::{c_char, c_int, c_void}; +use std::panic::catch_unwind; use std::ptr; use std::time::Duration; @@ -27,7 +28,7 @@ pub unsafe fn config_log(callback: Option) -> Result<()> { let callback: fn(c_int, &str) = unsafe { mem::transmute(p_arg) }; let s = String::from_utf8_lossy(c_slice); - callback(err, &s); + let _ = catch_unwind(|| callback(err, &s)); } let rc = match callback { @@ -72,7 +73,7 @@ impl Connection { let trace_fn: fn(&str) = mem::transmute(p_arg); let c_slice = CStr::from_ptr(z_sql).to_bytes(); let s = String::from_utf8_lossy(c_slice); - trace_fn(&s); + let _ = catch_unwind(|| trace_fn(&s)); } let c = self.db.borrow_mut(); @@ -106,7 +107,7 @@ impl Connection { nanoseconds / NANOS_PER_SEC, (nanoseconds % NANOS_PER_SEC) as u32, ); - profile_fn(&s, duration); + let _ = catch_unwind(|| profile_fn(&s, duration)); } let c = self.db.borrow_mut(); diff --git a/src/unlock_notify.rs b/src/unlock_notify.rs index 0f8f6fa..7295322 100644 --- a/src/unlock_notify.rs +++ b/src/unlock_notify.rs @@ -4,6 +4,8 @@ use std::os::raw::c_int; #[cfg(feature = "unlock_notify")] use std::os::raw::c_void; #[cfg(feature = "unlock_notify")] +use std::panic::catch_unwind; +#[cfg(feature = "unlock_notify")] use std::sync::{Condvar, Mutex}; use crate::ffi; @@ -42,8 +44,10 @@ unsafe extern "C" fn unlock_notify_cb(ap_arg: *mut *mut c_void, n_arg: c_int) { use std::slice::from_raw_parts; let args = from_raw_parts(ap_arg, n_arg as usize); for arg in args { - let un: &mut UnlockNotification = &mut *(*arg as *mut UnlockNotification); - un.fired(); + let _ = catch_unwind(|| { + let un: &mut UnlockNotification = &mut *(*arg as *mut UnlockNotification); + un.fired() + }); } }