diff --git a/src/error.rs b/src/error.rs index 2da7426..0ca03c0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -94,6 +94,11 @@ pub enum Error { #[cfg(feature = "functions")] UnwindingPanic, + + /// An error returned when `Context::get_aux` attempts to retrieve data + /// of a different type than what had been stored using `Context::set_aux`. + #[cfg(feature = "functions")] + GetAuxWrongType, } impl PartialEq for Error { @@ -131,6 +136,8 @@ impl PartialEq for Error { (Error::ModuleError(s1), Error::ModuleError(s2)) => s1 == s2, #[cfg(feature = "functions")] (Error::UnwindingPanic, Error::UnwindingPanic) => true, + #[cfg(feature = "functions")] + (Error::GetAuxWrongType, Error::GetAuxWrongType) => true, (_, _) => false, } } @@ -196,6 +203,8 @@ impl fmt::Display for Error { Error::ModuleError(ref desc) => write!(f, "{}", desc), #[cfg(feature = "functions")] Error::UnwindingPanic => write!(f, "unwinding panic"), + #[cfg(feature = "functions")] + Error::GetAuxWrongType => write!(f, "get_aux called with wrong type"), } } } @@ -235,6 +244,8 @@ impl error::Error for Error { Error::ModuleError(ref desc) => desc, #[cfg(feature = "functions")] Error::UnwindingPanic => "unwinding panic", + #[cfg(feature = "functions")] + Error::GetAuxWrongType => "get_aux called with wrong type", } } @@ -272,6 +283,9 @@ impl error::Error for Error { #[cfg(feature = "functions")] Error::UnwindingPanic => None, + + #[cfg(feature = "functions")] + Error::GetAuxWrongType => None, } } } diff --git a/src/functions.rs b/src/functions.rs index c4f398d..90ef6da 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -158,32 +158,32 @@ impl<'a> Context<'a> { /// Sets the auxilliary data associated with a particular parameter. See /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of /// this feature, or the unit tests of this module for an example. - pub fn set_aux(&self, arg: c_int, value: T) { - let boxed = Box::into_raw(Box::new(value)); + pub fn set_aux(&self, arg: c_int, value: T) { + let boxed = Box::into_raw(Box::new((std::any::TypeId::of::(), value))); unsafe { ffi::sqlite3_set_auxdata( self.ctx, arg, boxed as *mut c_void, - Some(free_boxed_value::), + Some(free_boxed_value::<(std::any::TypeId, T)>), ) }; } /// Gets the auxilliary data that was associated with a given parameter - /// via `set_aux`. Returns `None` if no data has been associated. - /// - /// # Unsafety - /// - /// This function is unsafe as there is no guarantee that the type `T` - /// requested matches the type `T` that was provided to `set_aux`. The - /// types must be identical. - pub unsafe fn get_aux(&self, arg: c_int) -> Option<&T> { - let p = ffi::sqlite3_get_auxdata(self.ctx, arg) as *mut T; + /// via `set_aux`. Returns `Ok(None)` if no data has been associated, + /// and . + pub fn get_aux(&self, arg: c_int) -> Result> { + let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *mut (std::any::TypeId, T) }; if p.is_null() { - None + Ok(None) } else { - Some(&*p) + let id_val = unsafe { &*p }; + if std::any::TypeId::of::() != id_val.0 { + Err(Error::GetAuxWrongType) + } else { + Ok(Some(&id_val.1)) + } } } } @@ -559,7 +559,7 @@ mod test { fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result { assert!(ctx.len() == 2, "called with unexpected number of arguments"); - let saved_re: Option<&Regex> = unsafe { ctx.get_aux(0) }; + let saved_re: Option<&Regex> = ctx.get_aux(0)?; let new_re = match saved_re { None => { let s = ctx.get::(0)?; @@ -699,6 +699,28 @@ mod test { } } + #[test] + fn test_get_aux_type_checking() { + let db = Connection::open_in_memory().unwrap(); + db.create_scalar_function("example", 2, false, |ctx| { + if !ctx.get::(1)? { + ctx.set_aux::(0, 100); + } else { + assert_eq!(ctx.get_aux::(0), Err(Error::GetAuxWrongType)); + assert_eq!(ctx.get_aux::(0), Ok(Some(&100))); + } + Ok(true) + }) + .unwrap(); + + let res: bool = db.query_row( + "SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)", + NO_PARAMS, + |r| r.get(0)).unwrap(); + // Doesn't actually matter, we'll assert in the function if there's a problem. + assert!(res); + } + struct Sum; struct Count;