Actually fix auxdata api...

This commit is contained in:
Thom Chiovoloni 2020-04-12 19:41:01 -07:00 committed by Thom Chiovoloni
parent 71b2f5187b
commit 2ef3628dac

View File

@ -12,6 +12,8 @@
//! use regex::Regex; //! use regex::Regex;
//! use rusqlite::functions::FunctionFlags; //! use rusqlite::functions::FunctionFlags;
//! use rusqlite::{Connection, Error, Result, NO_PARAMS}; //! use rusqlite::{Connection, Error, Result, NO_PARAMS};
//! use std::sync::Arc;
//! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
//! //!
//! fn add_regexp_function(db: &Connection) -> Result<()> { //! fn add_regexp_function(db: &Connection) -> Result<()> {
//! db.create_scalar_function( //! db.create_scalar_function(
@ -20,34 +22,19 @@
//! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, //! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
//! move |ctx| { //! move |ctx| {
//! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); //! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
//! //! let regexp: Arc<Regex> = ctx
//! let saved_re: Option<&Regex> = ctx.get_aux(0)?; //! .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
//! let new_re = match saved_re { //! Ok(Regex::new(vr.as_str()?)?)
//! None => { //! })?;
//! let s = ctx.get::<String>(0)?;
//! match Regex::new(&s) {
//! Ok(r) => Some(r),
//! Err(err) => return Err(Error::UserFunctionError(Box::new(err))),
//! }
//! }
//! Some(_) => None,
//! };
//!
//! let is_match = { //! let is_match = {
//! let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap());
//!
//! let text = ctx //! let text = ctx
//! .get_raw(1) //! .get_raw(1)
//! .as_str() //! .as_str()
//! .map_err(|e| Error::UserFunctionError(e.into()))?; //! .map_err(|e| Error::UserFunctionError(e.into()))?;
//! //!
//! re.is_match(text) //! regexp.is_match(text)
//! }; //! };
//! //!
//! if let Some(re) = new_re {
//! ctx.set_aux(0, re);
//! }
//!
//! Ok(is_match) //! Ok(is_match)
//! }, //! },
//! ) //! )
@ -67,11 +54,12 @@
//! Ok(()) //! Ok(())
//! } //! }
//! ``` //! ```
use std::any::TypeId; use std::any::Any;
use std::os::raw::{c_int, c_void}; use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe}; use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
use std::ptr; use std::ptr;
use std::slice; use std::slice;
use std::sync::Arc;
use crate::ffi; use crate::ffi;
use crate::ffi::sqlite3_context; use crate::ffi::sqlite3_context;
@ -121,6 +109,7 @@ unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
pub struct Context<'a> { pub struct Context<'a> {
ctx: *mut sqlite3_context, ctx: *mut sqlite3_context,
args: &'a [*mut sqlite3_value], args: &'a [*mut sqlite3_value],
// conn: PhantomData<&'conn mut Connection>,
} }
impl Context<'_> { impl Context<'_> {
@ -174,47 +163,60 @@ impl Context<'_> {
unsafe { ValueRef::from_value(arg) } unsafe { ValueRef::from_value(arg) }
} }
pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
where
T: Send + Sync + 'static,
E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
F: FnOnce(ValueRef<'_>) -> Result<T, E>,
{
if let Some(v) = self.get_aux(arg)? {
Ok(v)
} else {
let vr = self.get_raw(arg as usize);
self.set_aux(
arg,
func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
)
}
}
/// Sets the auxilliary data associated with a particular parameter. See /// Sets the auxilliary data associated with a particular parameter. See
/// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of
/// this feature, or the unit tests of this module for an example. /// this feature, or the unit tests of this module for an example.
pub fn set_aux<T: 'static>(&self, arg: c_int, value: T) { pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
let boxed = Box::into_raw(Box::new(AuxData { let orig: Arc<T> = Arc::new(value);
id: TypeId::of::<T>(), let inner: AuxInner = orig.clone();
value, let outer = Box::new(inner);
})); let raw: *mut AuxInner = Box::into_raw(outer);
unsafe { unsafe {
ffi::sqlite3_set_auxdata( ffi::sqlite3_set_auxdata(
self.ctx, self.ctx,
arg, arg,
boxed as *mut c_void, raw as *mut _,
Some(free_boxed_value::<AuxData<T>>), Some(free_boxed_value::<AuxInner>),
) )
}; };
Ok(orig)
} }
/// Gets the auxilliary data that was associated with a given parameter /// Gets the auxilliary data that was associated with a given parameter via
/// via `set_aux`. Returns `Ok(None)` if no data has been associated, /// `set_aux`. Returns `Ok(None)` if no data has been associated, and
/// and . /// Ok(Some(v)) if it has. Returns an error if the requested type does not
pub fn get_aux<T: 'static>(&self, arg: c_int) -> Result<Option<&T>> { /// match.
let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxData<T> }; pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
if p.is_null() { if p.is_null() {
Ok(None) Ok(None)
} else { } else {
let id = unsafe { (*p).id }; let v: AuxInner = AuxInner::clone(unsafe { &*p });
if TypeId::of::<T>() != id { v.downcast::<T>()
Err(Error::GetAuxWrongType) .map(Some)
} else { .map_err(|_| Error::GetAuxWrongType)
Ok(Some(unsafe { &(*p).value }))
}
} }
} }
} }
#[repr(C)] type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
struct AuxData<T: 'static> {
id: TypeId,
value: T,
}
/// `feature = "functions"` Aggregate is the callback interface for user-defined /// `feature = "functions"` Aggregate is the callback interface for user-defined
/// aggregate function. /// aggregate function.
@ -776,34 +778,21 @@ mod test {
// expression multiple times within one query. // expression multiple times within one query.
fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> { fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
let saved_re: Option<&Regex> = ctx.get_aux(0)?; let regexp: std::sync::Arc<Regex> = ctx
let new_re = match saved_re { .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
None => { Ok(Regex::new(vr.as_str()?)?)
let s = ctx.get::<String>(0)?; })?;
match Regex::new(&s) {
Ok(r) => Some(r),
Err(err) => return Err(Error::UserFunctionError(Box::new(err))),
}
}
Some(_) => None,
};
let is_match = { let is_match = {
let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap());
let text = ctx let text = ctx
.get_raw(1) .get_raw(1)
.as_str() .as_str()
.map_err(|e| Error::UserFunctionError(e.into()))?; .map_err(|e| Error::UserFunctionError(e.into()))?;
re.is_match(text) regexp.is_match(text)
}; };
if let Some(re) = new_re {
ctx.set_aux(0, re);
}
Ok(is_match) Ok(is_match)
} }
@ -878,10 +867,10 @@ mod test {
let db = Connection::open_in_memory().unwrap(); let db = Connection::open_in_memory().unwrap();
db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| { db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
if !ctx.get::<bool>(1)? { if !ctx.get::<bool>(1)? {
ctx.set_aux::<i64>(0, 100); ctx.set_aux::<i64>(0, 100)?;
} else { } else {
assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType)); assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
assert_eq!(ctx.get_aux::<i64>(0), Ok(Some(&100))); assert_eq!(*ctx.get_aux::<i64>(0).unwrap().unwrap(), 100);
} }
Ok(true) Ok(true)
}) })