mirror of
https://github.com/isar/rusqlite.git
synced 2024-11-23 00:39:20 +08:00
Merge pull request #463 from thomcc/safe-get-aux
Make get_aux safe by storing the TypeId with the data.
This commit is contained in:
commit
36846387be
14
src/error.rs
14
src/error.rs
@ -94,6 +94,11 @@ pub enum Error {
|
|||||||
|
|
||||||
#[cfg(feature = "functions")]
|
#[cfg(feature = "functions")]
|
||||||
UnwindingPanic,
|
UnwindingPanic,
|
||||||
|
|
||||||
|
/// An error returned when `Context::get_aux` attempts to retrieve data
|
||||||
|
/// of a different type than what had been stored using `Context::set_aux`.
|
||||||
|
#[cfg(feature = "functions")]
|
||||||
|
GetAuxWrongType,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PartialEq for Error {
|
impl PartialEq for Error {
|
||||||
@ -131,6 +136,8 @@ impl PartialEq for Error {
|
|||||||
(Error::ModuleError(s1), Error::ModuleError(s2)) => s1 == s2,
|
(Error::ModuleError(s1), Error::ModuleError(s2)) => s1 == s2,
|
||||||
#[cfg(feature = "functions")]
|
#[cfg(feature = "functions")]
|
||||||
(Error::UnwindingPanic, Error::UnwindingPanic) => true,
|
(Error::UnwindingPanic, Error::UnwindingPanic) => true,
|
||||||
|
#[cfg(feature = "functions")]
|
||||||
|
(Error::GetAuxWrongType, Error::GetAuxWrongType) => true,
|
||||||
(_, _) => false,
|
(_, _) => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -196,6 +203,8 @@ impl fmt::Display for Error {
|
|||||||
Error::ModuleError(ref desc) => write!(f, "{}", desc),
|
Error::ModuleError(ref desc) => write!(f, "{}", desc),
|
||||||
#[cfg(feature = "functions")]
|
#[cfg(feature = "functions")]
|
||||||
Error::UnwindingPanic => write!(f, "unwinding panic"),
|
Error::UnwindingPanic => write!(f, "unwinding panic"),
|
||||||
|
#[cfg(feature = "functions")]
|
||||||
|
Error::GetAuxWrongType => write!(f, "get_aux called with wrong type"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -235,6 +244,8 @@ impl error::Error for Error {
|
|||||||
Error::ModuleError(ref desc) => desc,
|
Error::ModuleError(ref desc) => desc,
|
||||||
#[cfg(feature = "functions")]
|
#[cfg(feature = "functions")]
|
||||||
Error::UnwindingPanic => "unwinding panic",
|
Error::UnwindingPanic => "unwinding panic",
|
||||||
|
#[cfg(feature = "functions")]
|
||||||
|
Error::GetAuxWrongType => "get_aux called with wrong type",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -272,6 +283,9 @@ impl error::Error for Error {
|
|||||||
|
|
||||||
#[cfg(feature = "functions")]
|
#[cfg(feature = "functions")]
|
||||||
Error::UnwindingPanic => None,
|
Error::UnwindingPanic => None,
|
||||||
|
|
||||||
|
#[cfg(feature = "functions")]
|
||||||
|
Error::GetAuxWrongType => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -158,32 +158,32 @@ impl<'a> Context<'a> {
|
|||||||
/// Sets the auxilliary data associated with a particular parameter. See
|
/// Sets the auxilliary data associated with a particular parameter. See
|
||||||
/// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
|
/// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
|
||||||
/// this feature, or the unit tests of this module for an example.
|
/// this feature, or the unit tests of this module for an example.
|
||||||
pub fn set_aux<T>(&self, arg: c_int, value: T) {
|
pub fn set_aux<T: 'static>(&self, arg: c_int, value: T) {
|
||||||
let boxed = Box::into_raw(Box::new(value));
|
let boxed = Box::into_raw(Box::new((std::any::TypeId::of::<T>(), value)));
|
||||||
unsafe {
|
unsafe {
|
||||||
ffi::sqlite3_set_auxdata(
|
ffi::sqlite3_set_auxdata(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
arg,
|
arg,
|
||||||
boxed as *mut c_void,
|
boxed as *mut c_void,
|
||||||
Some(free_boxed_value::<T>),
|
Some(free_boxed_value::<(std::any::TypeId, T)>),
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets the auxilliary data that was associated with a given parameter
|
/// Gets the auxilliary data that was associated with a given parameter
|
||||||
/// via `set_aux`. Returns `None` if no data has been associated.
|
/// via `set_aux`. Returns `Ok(None)` if no data has been associated,
|
||||||
///
|
/// and .
|
||||||
/// # Unsafety
|
pub fn get_aux<T: 'static>(&self, arg: c_int) -> Result<Option<&T>> {
|
||||||
///
|
let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *mut (std::any::TypeId, T) };
|
||||||
/// This function is unsafe as there is no guarantee that the type `T`
|
|
||||||
/// requested matches the type `T` that was provided to `set_aux`. The
|
|
||||||
/// types must be identical.
|
|
||||||
pub unsafe fn get_aux<T>(&self, arg: c_int) -> Option<&T> {
|
|
||||||
let p = ffi::sqlite3_get_auxdata(self.ctx, arg) as *mut T;
|
|
||||||
if p.is_null() {
|
if p.is_null() {
|
||||||
None
|
Ok(None)
|
||||||
} else {
|
} else {
|
||||||
Some(&*p)
|
let id_val = unsafe { &*p };
|
||||||
|
if std::any::TypeId::of::<T>() != id_val.0 {
|
||||||
|
Err(Error::GetAuxWrongType)
|
||||||
|
} else {
|
||||||
|
Ok(Some(&id_val.1))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -559,7 +559,7 @@ mod test {
|
|||||||
fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
|
fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
|
||||||
assert!(ctx.len() == 2, "called with unexpected number of arguments");
|
assert!(ctx.len() == 2, "called with unexpected number of arguments");
|
||||||
|
|
||||||
let saved_re: Option<&Regex> = unsafe { ctx.get_aux(0) };
|
let saved_re: Option<&Regex> = ctx.get_aux(0)?;
|
||||||
let new_re = match saved_re {
|
let new_re = match saved_re {
|
||||||
None => {
|
None => {
|
||||||
let s = ctx.get::<String>(0)?;
|
let s = ctx.get::<String>(0)?;
|
||||||
@ -699,6 +699,28 @@ mod test {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_get_aux_type_checking() {
|
||||||
|
let db = Connection::open_in_memory().unwrap();
|
||||||
|
db.create_scalar_function("example", 2, false, |ctx| {
|
||||||
|
if !ctx.get::<bool>(1)? {
|
||||||
|
ctx.set_aux::<i64>(0, 100);
|
||||||
|
} else {
|
||||||
|
assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
|
||||||
|
assert_eq!(ctx.get_aux::<i64>(0), Ok(Some(&100)));
|
||||||
|
}
|
||||||
|
Ok(true)
|
||||||
|
})
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let res: bool = db.query_row(
|
||||||
|
"SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)",
|
||||||
|
NO_PARAMS,
|
||||||
|
|r| r.get(0)).unwrap();
|
||||||
|
// Doesn't actually matter, we'll assert in the function if there's a problem.
|
||||||
|
assert!(res);
|
||||||
|
}
|
||||||
|
|
||||||
struct Sum;
|
struct Sum;
|
||||||
struct Count;
|
struct Count;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user