Callbacks must not be able to unwind into sqlite code

This commit is contained in:
gwenn 2018-12-16 09:40:14 +01:00
parent bdfc2dfc54
commit bd9b850c43
6 changed files with 100 additions and 48 deletions

View File

@ -1,6 +1,7 @@
///! Busy handler (when the database is locked) ///! Busy handler (when the database is locked)
use std::mem; use std::mem;
use std::os::raw::{c_int, c_void}; use std::os::raw::{c_int, c_void};
use std::panic::catch_unwind;
use std::ptr; use std::ptr;
use std::time::Duration; use std::time::Duration;
@ -48,7 +49,7 @@ impl Connection {
pub fn busy_handler(&self, callback: Option<fn(i32) -> bool>) -> Result<()> { pub fn busy_handler(&self, callback: Option<fn(i32) -> bool>) -> Result<()> {
unsafe extern "C" fn busy_handler_callback(p_arg: *mut c_void, count: c_int) -> c_int { unsafe extern "C" fn busy_handler_callback(p_arg: *mut c_void, count: c_int) -> c_int {
let handler_fn: fn(i32) -> bool = mem::transmute(p_arg); let handler_fn: fn(i32) -> bool = mem::transmute(p_arg);
if handler_fn(count) { if let Ok(true) = catch_unwind(|| handler_fn(count)) {
1 1
} else { } else {
0 0

View File

@ -91,6 +91,9 @@ pub enum Error {
#[cfg(feature = "vtab")] #[cfg(feature = "vtab")]
#[allow(dead_code)] #[allow(dead_code)]
ModuleError(String), ModuleError(String),
#[cfg(feature = "functions")]
UnwindingPanic,
} }
impl From<str::Utf8Error> for Error { impl From<str::Utf8Error> for Error {
@ -151,6 +154,8 @@ impl fmt::Display for Error {
Error::InvalidQuery => write!(f, "Query is not read-only"), Error::InvalidQuery => write!(f, "Query is not read-only"),
#[cfg(feature = "vtab")] #[cfg(feature = "vtab")]
Error::ModuleError(ref desc) => write!(f, "{}", desc), Error::ModuleError(ref desc) => write!(f, "{}", desc),
#[cfg(feature = "functions")]
Error::UnwindingPanic => write!(f, "unwinding panic"),
} }
} }
} }
@ -188,6 +193,8 @@ impl error::Error for Error {
Error::InvalidQuery => "query is not read-only", Error::InvalidQuery => "query is not read-only",
#[cfg(feature = "vtab")] #[cfg(feature = "vtab")]
Error::ModuleError(ref desc) => desc, Error::ModuleError(ref desc) => desc,
#[cfg(feature = "functions")]
Error::UnwindingPanic => "unwinding panic",
} }
} }
@ -222,6 +229,9 @@ impl error::Error for Error {
#[cfg(feature = "vtab")] #[cfg(feature = "vtab")]
Error::ModuleError(_) => None, Error::ModuleError(_) => None,
#[cfg(feature = "functions")]
Error::UnwindingPanic => None,
} }
} }
} }

View File

@ -55,6 +55,7 @@
//! ``` //! ```
use std::error::Error as StdError; use std::error::Error as StdError;
use std::os::raw::{c_int, c_void}; use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
use std::ptr; use std::ptr;
use std::slice; use std::slice;
@ -189,6 +190,7 @@ impl<'a> Context<'a> {
/// result. Implementations should be stateless. /// result. Implementations should be stateless.
pub trait Aggregate<A, T> pub trait Aggregate<A, T>
where where
A: RefUnwindSafe + UnwindSafe,
T: ToSql, T: ToSql,
{ {
/// Initializes the aggregation context. Will be called prior to the first /// Initializes the aggregation context. Will be called prior to the first
@ -246,7 +248,7 @@ impl Connection {
x_func: F, x_func: F,
) -> Result<()> ) -> Result<()>
where where
F: FnMut(&Context<'_>) -> Result<T> + Send + 'static, F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
T: ToSql, T: ToSql,
{ {
self.db self.db
@ -267,6 +269,7 @@ impl Connection {
aggr: D, aggr: D,
) -> Result<()> ) -> Result<()>
where where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>, D: Aggregate<A, T>,
T: ToSql, T: ToSql,
{ {
@ -297,7 +300,7 @@ impl InnerConnection {
x_func: F, x_func: F,
) -> Result<()> ) -> Result<()>
where where
F: FnMut(&Context<'_>) -> Result<T> + Send + 'static, F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
T: ToSql, T: ToSql,
{ {
unsafe extern "C" fn call_boxed_closure<F, T>( unsafe extern "C" fn call_boxed_closure<F, T>(
@ -308,20 +311,28 @@ impl InnerConnection {
F: FnMut(&Context<'_>) -> Result<T>, F: FnMut(&Context<'_>) -> Result<T>,
T: ToSql, T: ToSql,
{ {
let r = catch_unwind(|| {
let boxed_f: *mut F = ffi::sqlite3_user_data(ctx) as *mut F;
assert!(!boxed_f.is_null(), "Internal error - null function pointer");
let ctx = Context { let ctx = Context {
ctx, ctx,
args: slice::from_raw_parts(argv, argc as usize), args: slice::from_raw_parts(argv, argc as usize),
}; };
let boxed_f: *mut F = ffi::sqlite3_user_data(ctx.ctx) as *mut F; (*boxed_f)(&ctx)
assert!(!boxed_f.is_null(), "Internal error - null function pointer"); });
let t = match r {
let t = (*boxed_f)(&ctx); Err(_) => {
report_error(ctx, &Error::UnwindingPanic);
return;
}
Ok(r) => r,
};
let t = t.as_ref().map(|t| ToSql::to_sql(t)); let t = t.as_ref().map(|t| ToSql::to_sql(t));
match t { match t {
Ok(Ok(ref value)) => set_result(ctx.ctx, value), Ok(Ok(ref value)) => set_result(ctx, value),
Ok(Err(err)) => report_error(ctx.ctx, &err), Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx.ctx, err), Err(err) => report_error(ctx, err),
} }
} }
@ -355,6 +366,7 @@ impl InnerConnection {
aggr: D, aggr: D,
) -> Result<()> ) -> Result<()>
where where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>, D: Aggregate<A, T>,
T: ToSql, T: ToSql,
{ {
@ -374,15 +386,10 @@ impl InnerConnection {
argc: c_int, argc: c_int,
argv: *mut *mut sqlite3_value, argv: *mut *mut sqlite3_value,
) where ) where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>, D: Aggregate<A, T>,
T: ToSql, T: ToSql,
{ {
let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
);
let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) { let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
Some(pac) => pac, Some(pac) => pac,
None => { None => {
@ -391,32 +398,40 @@ impl InnerConnection {
} }
}; };
if (*pac as *mut A).is_null() { let r = catch_unwind(|| {
*pac = Box::into_raw(Box::new((*boxed_aggr).init()));
}
let mut ctx = Context {
ctx,
args: slice::from_raw_parts(argv, argc as usize),
};
match (*boxed_aggr).step(&mut ctx, &mut **pac) {
Ok(_) => {}
Err(err) => report_error(ctx.ctx, &err),
};
}
unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
where
D: Aggregate<A, T>,
T: ToSql,
{
let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
assert!( assert!(
!boxed_aggr.is_null(), !boxed_aggr.is_null(),
"Internal error - null aggregate pointer" "Internal error - null aggregate pointer"
); );
if (*pac as *mut A).is_null() {
*pac = Box::into_raw(Box::new((*boxed_aggr).init()));
}
let mut ctx = Context {
ctx,
args: slice::from_raw_parts(argv, argc as usize),
};
(*boxed_aggr).step(&mut ctx, &mut **pac)
});
let r = match r {
Err(_) => {
report_error(ctx, &Error::UnwindingPanic);
return;
}
Ok(r) => r,
};
match r {
Ok(_) => {}
Err(err) => report_error(ctx, &err),
};
}
unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>,
T: ToSql,
{
// Within the xFinal callback, it is customary to set N=0 in calls to // Within the xFinal callback, it is customary to set N=0 in calls to
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
let a: Option<A> = match aggregate_context(ctx, 0) { let a: Option<A> = match aggregate_context(ctx, 0) {
@ -431,7 +446,21 @@ impl InnerConnection {
None => None, None => None,
}; };
let t = (*boxed_aggr).finalize(a); let r = catch_unwind(|| {
let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
);
(*boxed_aggr).finalize(a)
});
let t = match r {
Err(_) => {
report_error(ctx, &Error::UnwindingPanic);
return;
}
Ok(r) => r,
};
let t = t.as_ref().map(|t| ToSql::to_sql(t)); let t = t.as_ref().map(|t| ToSql::to_sql(t));
match t { match t {
Ok(Ok(ref value)) => set_result(ctx, value), Ok(Ok(ref value)) => set_result(ctx, value),

View File

@ -2,6 +2,7 @@
#![allow(non_camel_case_types)] #![allow(non_camel_case_types)]
use std::os::raw::{c_char, c_int, c_void}; use std::os::raw::{c_char, c_int, c_void};
use std::panic::catch_unwind;
use std::ptr; use std::ptr;
use crate::ffi; use crate::ffi;
@ -146,8 +147,11 @@ impl InnerConnection {
where where
F: FnMut() -> bool, F: FnMut() -> bool,
{ {
let r = catch_unwind(|| {
let boxed_hook: *mut F = p_arg as *mut F; let boxed_hook: *mut F = p_arg as *mut F;
if (*boxed_hook)() { (*boxed_hook)()
});
if let Ok(true) = r {
1 1
} else { } else {
0 0
@ -192,8 +196,10 @@ impl InnerConnection {
where where
F: FnMut(), F: FnMut(),
{ {
let _ = catch_unwind(|| {
let boxed_hook: *mut F = p_arg as *mut F; let boxed_hook: *mut F = p_arg as *mut F;
(*boxed_hook)(); (*boxed_hook)();
});
} }
let free_rollback_hook = if hook.is_some() { let free_rollback_hook = if hook.is_some() {
@ -239,8 +245,6 @@ impl InnerConnection {
use std::ffi::CStr; use std::ffi::CStr;
use std::str; use std::str;
let boxed_hook: *mut F = p_arg as *mut F;
let action = Action::from(action_code); let action = Action::from(action_code);
let db_name = { let db_name = {
let c_slice = CStr::from_ptr(db_str).to_bytes(); let c_slice = CStr::from_ptr(db_str).to_bytes();
@ -251,7 +255,10 @@ impl InnerConnection {
str::from_utf8_unchecked(c_slice) str::from_utf8_unchecked(c_slice)
}; };
let _ = catch_unwind(|| {
let boxed_hook: *mut F = p_arg as *mut F;
(*boxed_hook)(action, db_name, tbl_name, row_id); (*boxed_hook)(action, db_name, tbl_name, row_id);
});
} }
let free_update_hook = if hook.is_some() { let free_update_hook = if hook.is_some() {

View File

@ -3,6 +3,7 @@
use std::ffi::{CStr, CString}; use std::ffi::{CStr, CString};
use std::mem; use std::mem;
use std::os::raw::{c_char, c_int, c_void}; use std::os::raw::{c_char, c_int, c_void};
use std::panic::catch_unwind;
use std::ptr; use std::ptr;
use std::time::Duration; use std::time::Duration;
@ -27,7 +28,7 @@ pub unsafe fn config_log(callback: Option<fn(c_int, &str)>) -> Result<()> {
let callback: fn(c_int, &str) = unsafe { mem::transmute(p_arg) }; let callback: fn(c_int, &str) = unsafe { mem::transmute(p_arg) };
let s = String::from_utf8_lossy(c_slice); let s = String::from_utf8_lossy(c_slice);
callback(err, &s); let _ = catch_unwind(|| callback(err, &s));
} }
let rc = match callback { let rc = match callback {
@ -72,7 +73,7 @@ impl Connection {
let trace_fn: fn(&str) = mem::transmute(p_arg); let trace_fn: fn(&str) = mem::transmute(p_arg);
let c_slice = CStr::from_ptr(z_sql).to_bytes(); let c_slice = CStr::from_ptr(z_sql).to_bytes();
let s = String::from_utf8_lossy(c_slice); let s = String::from_utf8_lossy(c_slice);
trace_fn(&s); let _ = catch_unwind(|| trace_fn(&s));
} }
let c = self.db.borrow_mut(); let c = self.db.borrow_mut();
@ -106,7 +107,7 @@ impl Connection {
nanoseconds / NANOS_PER_SEC, nanoseconds / NANOS_PER_SEC,
(nanoseconds % NANOS_PER_SEC) as u32, (nanoseconds % NANOS_PER_SEC) as u32,
); );
profile_fn(&s, duration); let _ = catch_unwind(|| profile_fn(&s, duration));
} }
let c = self.db.borrow_mut(); let c = self.db.borrow_mut();

View File

@ -4,6 +4,8 @@ use std::os::raw::c_int;
#[cfg(feature = "unlock_notify")] #[cfg(feature = "unlock_notify")]
use std::os::raw::c_void; use std::os::raw::c_void;
#[cfg(feature = "unlock_notify")] #[cfg(feature = "unlock_notify")]
use std::panic::catch_unwind;
#[cfg(feature = "unlock_notify")]
use std::sync::{Condvar, Mutex}; use std::sync::{Condvar, Mutex};
use crate::ffi; use crate::ffi;
@ -42,8 +44,10 @@ unsafe extern "C" fn unlock_notify_cb(ap_arg: *mut *mut c_void, n_arg: c_int) {
use std::slice::from_raw_parts; use std::slice::from_raw_parts;
let args = from_raw_parts(ap_arg, n_arg as usize); let args = from_raw_parts(ap_arg, n_arg as usize);
for arg in args { for arg in args {
let _ = catch_unwind(|| {
let un: &mut UnlockNotification = &mut *(*arg as *mut UnlockNotification); let un: &mut UnlockNotification = &mut *(*arg as *mut UnlockNotification);
un.fired(); un.fired()
});
} }
} }