Actually fix auxdata api...

This commit is contained in:
Thom Chiovoloni 2020-04-12 19:41:01 -07:00 committed by Thom Chiovoloni
parent 71b2f5187b
commit 2ef3628dac

View File

@ -12,6 +12,8 @@
//! use regex::Regex;
//! use rusqlite::functions::FunctionFlags;
//! use rusqlite::{Connection, Error, Result, NO_PARAMS};
//! use std::sync::Arc;
//! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
//!
//! fn add_regexp_function(db: &Connection) -> Result<()> {
//! db.create_scalar_function(
@ -20,34 +22,19 @@
//! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
//! move |ctx| {
//! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
//!
//! let saved_re: Option<&Regex> = ctx.get_aux(0)?;
//! let new_re = match saved_re {
//! None => {
//! let s = ctx.get::<String>(0)?;
//! match Regex::new(&s) {
//! Ok(r) => Some(r),
//! Err(err) => return Err(Error::UserFunctionError(Box::new(err))),
//! }
//! }
//! Some(_) => None,
//! };
//!
//! let regexp: Arc<Regex> = ctx
//! .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
//! Ok(Regex::new(vr.as_str()?)?)
//! })?;
//! let is_match = {
//! let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap());
//!
//! let text = ctx
//! .get_raw(1)
//! .as_str()
//! .map_err(|e| Error::UserFunctionError(e.into()))?;
//!
//! re.is_match(text)
//! regexp.is_match(text)
//! };
//!
//! if let Some(re) = new_re {
//! ctx.set_aux(0, re);
//! }
//!
//! Ok(is_match)
//! },
//! )
@ -67,11 +54,12 @@
//! Ok(())
//! }
//! ```
use std::any::TypeId;
use std::any::Any;
use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
use std::ptr;
use std::slice;
use std::sync::Arc;
use crate::ffi;
use crate::ffi::sqlite3_context;
@ -121,6 +109,7 @@ unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
pub struct Context<'a> {
ctx: *mut sqlite3_context,
args: &'a [*mut sqlite3_value],
// conn: PhantomData<&'conn mut Connection>,
}
impl Context<'_> {
@ -174,47 +163,60 @@ impl Context<'_> {
unsafe { ValueRef::from_value(arg) }
}
pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
where
T: Send + Sync + 'static,
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
F: FnOnce(ValueRef<'_>) -> Result<T, E>,
{
if let Some(v) = self.get_aux(arg)? {
Ok(v)
} else {
let vr = self.get_raw(arg as usize);
self.set_aux(
arg,
func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
)
}
}
/// Sets the auxilliary data associated with a particular parameter. See
/// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
/// this feature, or the unit tests of this module for an example.
pub fn set_aux<T: 'static>(&self, arg: c_int, value: T) {
let boxed = Box::into_raw(Box::new(AuxData {
id: TypeId::of::<T>(),
value,
}));
pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
let orig: Arc<T> = Arc::new(value);
let inner: AuxInner = orig.clone();
let outer = Box::new(inner);
let raw: *mut AuxInner = Box::into_raw(outer);
unsafe {
ffi::sqlite3_set_auxdata(
self.ctx,
arg,
boxed as *mut c_void,
Some(free_boxed_value::<AuxData<T>>),
raw as *mut _,
Some(free_boxed_value::<AuxInner>),
)
};
Ok(orig)
}
/// Gets the auxilliary data that was associated with a given parameter
/// via `set_aux`. Returns `Ok(None)` if no data has been associated,
/// and .
pub fn get_aux<T: 'static>(&self, arg: c_int) -> Result<Option<&T>> {
let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxData<T> };
/// Gets the auxilliary data that was associated with a given parameter via
/// `set_aux`. Returns `Ok(None)` if no data has been associated, and
/// Ok(Some(v)) if it has. Returns an error if the requested type does not
/// match.
pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
if p.is_null() {
Ok(None)
} else {
let id = unsafe { (*p).id };
if TypeId::of::<T>() != id {
Err(Error::GetAuxWrongType)
} else {
Ok(Some(unsafe { &(*p).value }))
}
let v: AuxInner = AuxInner::clone(unsafe { &*p });
v.downcast::<T>()
.map(Some)
.map_err(|_| Error::GetAuxWrongType)
}
}
}
#[repr(C)]
struct AuxData<T: 'static> {
id: TypeId,
value: T,
}
type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
/// `feature = "functions"` Aggregate is the callback interface for user-defined
/// aggregate function.
@ -776,34 +778,21 @@ mod test {
// expression multiple times within one query.
fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
let saved_re: Option<&Regex> = ctx.get_aux(0)?;
let new_re = match saved_re {
None => {
let s = ctx.get::<String>(0)?;
match Regex::new(&s) {
Ok(r) => Some(r),
Err(err) => return Err(Error::UserFunctionError(Box::new(err))),
}
}
Some(_) => None,
};
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
let regexp: std::sync::Arc<Regex> = ctx
.get_or_create_aux(0, |vr| -> Result<_, BoxError> {
Ok(Regex::new(vr.as_str()?)?)
})?;
let is_match = {
let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap());
let text = ctx
.get_raw(1)
.as_str()
.map_err(|e| Error::UserFunctionError(e.into()))?;
re.is_match(text)
regexp.is_match(text)
};
if let Some(re) = new_re {
ctx.set_aux(0, re);
}
Ok(is_match)
}
@ -878,10 +867,10 @@ mod test {
let db = Connection::open_in_memory().unwrap();
db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
if !ctx.get::<bool>(1)? {
ctx.set_aux::<i64>(0, 100);
ctx.set_aux::<i64>(0, 100)?;
} else {
assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
assert_eq!(ctx.get_aux::<i64>(0), Ok(Some(&100)));
assert_eq!(*ctx.get_aux::<i64>(0).unwrap().unwrap(), 100);
}
Ok(true)
})