Ensure type use for auxdata is repr(C)

This commit is contained in:
Thom Chiovoloni 2020-04-12 11:17:56 -07:00 committed by Thom Chiovoloni
parent 38aea89809
commit 71b2f5187b

View File

@ -67,6 +67,7 @@
//! Ok(()) //! Ok(())
//! } //! }
//! ``` //! ```
use std::any::TypeId;
use std::os::raw::{c_int, c_void}; use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe}; use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
use std::ptr; use std::ptr;
@ -177,13 +178,16 @@ impl Context<'_> {
/// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
/// this feature, or the unit tests of this module for an example. /// this feature, or the unit tests of this module for an example.
pub fn set_aux<T: 'static>(&self, arg: c_int, value: T) { pub fn set_aux<T: 'static>(&self, arg: c_int, value: T) {
let boxed = Box::into_raw(Box::new((std::any::TypeId::of::<T>(), value))); let boxed = Box::into_raw(Box::new(AuxData {
id: TypeId::of::<T>(),
value,
}));
unsafe { unsafe {
ffi::sqlite3_set_auxdata( ffi::sqlite3_set_auxdata(
self.ctx, self.ctx,
arg, arg,
boxed as *mut c_void, boxed as *mut c_void,
Some(free_boxed_value::<(std::any::TypeId, T)>), Some(free_boxed_value::<AuxData<T>>),
) )
}; };
} }
@ -192,20 +196,26 @@ impl Context<'_> {
/// via `set_aux`. Returns `Ok(None)` if no data has been associated, /// via `set_aux`. Returns `Ok(None)` if no data has been associated,
/// and . /// and .
pub fn get_aux<T: 'static>(&self, arg: c_int) -> Result<Option<&T>> { pub fn get_aux<T: 'static>(&self, arg: c_int) -> Result<Option<&T>> {
let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *mut (std::any::TypeId, T) }; let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxData<T> };
if p.is_null() { if p.is_null() {
Ok(None) Ok(None)
} else { } else {
let id_val = unsafe { &*p }; let id = unsafe { (*p).id };
if std::any::TypeId::of::<T>() != id_val.0 { if TypeId::of::<T>() != id {
Err(Error::GetAuxWrongType) Err(Error::GetAuxWrongType)
} else { } else {
Ok(Some(&id_val.1)) Ok(Some(unsafe { &(*p).value }))
} }
} }
} }
} }
#[repr(C)]
struct AuxData<T: 'static> {
id: TypeId,
value: T,
}
/// `feature = "functions"` Aggregate is the callback interface for user-defined /// `feature = "functions"` Aggregate is the callback interface for user-defined
/// aggregate function. /// aggregate function.
/// ///