Make get_aux safe by storing the TypeId

This commit is contained in:
Thom Chiovoloni
2019-01-25 00:43:50 -08:00
parent 6d1c915c2b
commit b6539a0fbf
2 changed files with 51 additions and 15 deletions

View File

@@ -158,32 +158,32 @@ impl<'a> Context<'a> {
/// Sets the auxilliary data associated with a particular parameter. See
/// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
/// this feature, or the unit tests of this module for an example.
pub fn set_aux<T>(&self, arg: c_int, value: T) {
let boxed = Box::into_raw(Box::new(value));
pub fn set_aux<T: 'static>(&self, arg: c_int, value: T) {
let boxed = Box::into_raw(Box::new((std::any::TypeId::of::<T>(), value)));
unsafe {
ffi::sqlite3_set_auxdata(
self.ctx,
arg,
boxed as *mut c_void,
Some(free_boxed_value::<T>),
Some(free_boxed_value::<(std::any::TypeId, T)>),
)
};
}
/// Gets the auxilliary data that was associated with a given parameter
/// via `set_aux`. Returns `None` if no data has been associated.
///
/// # Unsafety
///
/// This function is unsafe as there is no guarantee that the type `T`
/// requested matches the type `T` that was provided to `set_aux`. The
/// types must be identical.
pub unsafe fn get_aux<T>(&self, arg: c_int) -> Option<&T> {
let p = ffi::sqlite3_get_auxdata(self.ctx, arg) as *mut T;
/// via `set_aux`. Returns `Ok(None)` if no data has been associated,
/// and .
pub fn get_aux<T: 'static>(&self, arg: c_int) -> Result<Option<&T>> {
let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *mut (std::any::TypeId, T) };
if p.is_null() {
None
Ok(None)
} else {
Some(&*p)
let id_val = unsafe { &*p };
if std::any::TypeId::of::<T>() != id_val.0 {
Err(Error::GetAuxWrongType)
} else {
Ok(Some(&id_val.1))
}
}
}
}
@@ -559,7 +559,7 @@ mod test {
fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
assert!(ctx.len() == 2, "called with unexpected number of arguments");
let saved_re: Option<&Regex> = unsafe { ctx.get_aux(0) };
let saved_re: Option<&Regex> = ctx.get_aux(0)?;
let new_re = match saved_re {
None => {
let s = ctx.get::<String>(0)?;
@@ -699,6 +699,28 @@ mod test {
}
}
#[test]
fn test_get_aux_type_checking() {
let db = Connection::open_in_memory().unwrap();
db.create_scalar_function("example", 2, false, |ctx| {
if !ctx.get::<bool>(1)? {
ctx.set_aux::<i64>(0, 100);
} else {
assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
assert_eq!(ctx.get_aux::<i64>(0), Ok(Some(&100)));
}
Ok(true)
})
.unwrap();
let res: bool = db.query_row(
"SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
NO_PARAMS,
|r| r.get(0)).unwrap();
// Doesn't actually matter, we'll assert in the function if there's a problem.
assert!(res);
}
struct Sum;
struct Count;