From 2ef3628dac35aeba0a97d5fb3a57746b4e1d62b3 Mon Sep 17 00:00:00 2001 From: Thom Chiovoloni Date: Sun, 12 Apr 2020 19:41:01 -0700 Subject: [PATCH] Actually fix auxdata api... --- src/functions.rs | 119 +++++++++++++++++++++-------------------------- 1 file changed, 54 insertions(+), 65 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index df40b18..0c76244 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -12,6 +12,8 @@ //! use regex::Regex; //! use rusqlite::functions::FunctionFlags; //! use rusqlite::{Connection, Error, Result, NO_PARAMS}; +//! use std::sync::Arc; +//! type BoxError = Box; //! //! fn add_regexp_function(db: &Connection) -> Result<()> { //! db.create_scalar_function( @@ -20,34 +22,19 @@ //! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, //! move |ctx| { //! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); -//! -//! let saved_re: Option<&Regex> = ctx.get_aux(0)?; -//! let new_re = match saved_re { -//! None => { -//! let s = ctx.get::(0)?; -//! match Regex::new(&s) { -//! Ok(r) => Some(r), -//! Err(err) => return Err(Error::UserFunctionError(Box::new(err))), -//! } -//! } -//! Some(_) => None, -//! }; -//! +//! let regexp: Arc = ctx +//! .get_or_create_aux(0, |vr| -> Result<_, BoxError> { +//! Ok(Regex::new(vr.as_str()?)?) +//! })?; //! let is_match = { -//! let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); -//! //! let text = ctx //! .get_raw(1) //! .as_str() //! .map_err(|e| Error::UserFunctionError(e.into()))?; //! -//! re.is_match(text) +//! regexp.is_match(text) //! }; //! -//! if let Some(re) = new_re { -//! ctx.set_aux(0, re); -//! } -//! //! Ok(is_match) //! }, //! ) @@ -67,11 +54,12 @@ //! Ok(()) //! } //! ``` -use std::any::TypeId; +use std::any::Any; use std::os::raw::{c_int, c_void}; use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe}; use std::ptr; use std::slice; +use std::sync::Arc; use crate::ffi; use crate::ffi::sqlite3_context; @@ -121,6 +109,7 @@ unsafe extern "C" fn free_boxed_value(p: *mut c_void) { pub struct Context<'a> { ctx: *mut sqlite3_context, args: &'a [*mut sqlite3_value], + // conn: PhantomData<&'conn mut Connection>, } impl Context<'_> { @@ -174,47 +163,60 @@ impl Context<'_> { unsafe { ValueRef::from_value(arg) } } + pub fn get_or_create_aux(&self, arg: c_int, func: F) -> Result> + where + T: Send + Sync + 'static, + E: Into>, + F: FnOnce(ValueRef<'_>) -> Result, + { + if let Some(v) = self.get_aux(arg)? { + Ok(v) + } else { + let vr = self.get_raw(arg as usize); + self.set_aux( + arg, + func(vr).map_err(|e| Error::UserFunctionError(e.into()))?, + ) + } + } + /// Sets the auxilliary data associated with a particular parameter. See /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of /// this feature, or the unit tests of this module for an example. - pub fn set_aux(&self, arg: c_int, value: T) { - let boxed = Box::into_raw(Box::new(AuxData { - id: TypeId::of::(), - value, - })); + pub fn set_aux(&self, arg: c_int, value: T) -> Result> { + let orig: Arc = Arc::new(value); + let inner: AuxInner = orig.clone(); + let outer = Box::new(inner); + let raw: *mut AuxInner = Box::into_raw(outer); unsafe { ffi::sqlite3_set_auxdata( self.ctx, arg, - boxed as *mut c_void, - Some(free_boxed_value::>), + raw as *mut _, + Some(free_boxed_value::), ) }; + Ok(orig) } - /// Gets the auxilliary data that was associated with a given parameter - /// via `set_aux`. Returns `Ok(None)` if no data has been associated, - /// and . - pub fn get_aux(&self, arg: c_int) -> Result> { - let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxData }; + /// Gets the auxilliary data that was associated with a given parameter via + /// `set_aux`. Returns `Ok(None)` if no data has been associated, and + /// Ok(Some(v)) if it has. Returns an error if the requested type does not + /// match. + pub fn get_aux(&self, arg: c_int) -> Result>> { + let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner }; if p.is_null() { Ok(None) } else { - let id = unsafe { (*p).id }; - if TypeId::of::() != id { - Err(Error::GetAuxWrongType) - } else { - Ok(Some(unsafe { &(*p).value })) - } + let v: AuxInner = AuxInner::clone(unsafe { &*p }); + v.downcast::() + .map(Some) + .map_err(|_| Error::GetAuxWrongType) } } } -#[repr(C)] -struct AuxData { - id: TypeId, - value: T, -} +type AuxInner = Arc; /// `feature = "functions"` Aggregate is the callback interface for user-defined /// aggregate function. @@ -776,34 +778,21 @@ mod test { // expression multiple times within one query. fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result { assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); - - let saved_re: Option<&Regex> = ctx.get_aux(0)?; - let new_re = match saved_re { - None => { - let s = ctx.get::(0)?; - match Regex::new(&s) { - Ok(r) => Some(r), - Err(err) => return Err(Error::UserFunctionError(Box::new(err))), - } - } - Some(_) => None, - }; + type BoxError = Box; + let regexp: std::sync::Arc = ctx + .get_or_create_aux(0, |vr| -> Result<_, BoxError> { + Ok(Regex::new(vr.as_str()?)?) + })?; let is_match = { - let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); - let text = ctx .get_raw(1) .as_str() .map_err(|e| Error::UserFunctionError(e.into()))?; - re.is_match(text) + regexp.is_match(text) }; - if let Some(re) = new_re { - ctx.set_aux(0, re); - } - Ok(is_match) } @@ -878,10 +867,10 @@ mod test { let db = Connection::open_in_memory().unwrap(); db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| { if !ctx.get::(1)? { - ctx.set_aux::(0, 100); + ctx.set_aux::(0, 100)?; } else { assert_eq!(ctx.get_aux::(0), Err(Error::GetAuxWrongType)); - assert_eq!(ctx.get_aux::(0), Ok(Some(&100))); + assert_eq!(*ctx.get_aux::(0).unwrap().unwrap(), 100); } Ok(true) })