mirror of
				https://github.com/isar/rusqlite.git
				synced 2025-10-31 05:48:56 +08:00 
			
		
		
		
	Actually fix auxdata api...
This commit is contained in:
		
				
					committed by
					
						 Thom Chiovoloni
						Thom Chiovoloni
					
				
			
			
				
	
			
			
			
						parent
						
							71b2f5187b
						
					
				
				
					commit
					2ef3628dac
				
			
							
								
								
									
										119
									
								
								src/functions.rs
									
									
									
									
									
								
							
							
						
						
									
										119
									
								
								src/functions.rs
									
									
									
									
									
								
							| @@ -12,6 +12,8 @@ | |||||||
| //! use regex::Regex; | //! use regex::Regex; | ||||||
| //! use rusqlite::functions::FunctionFlags; | //! use rusqlite::functions::FunctionFlags; | ||||||
| //! use rusqlite::{Connection, Error, Result, NO_PARAMS}; | //! use rusqlite::{Connection, Error, Result, NO_PARAMS}; | ||||||
|  | //! use std::sync::Arc; | ||||||
|  | //! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; | ||||||
| //! | //! | ||||||
| //! fn add_regexp_function(db: &Connection) -> Result<()> { | //! fn add_regexp_function(db: &Connection) -> Result<()> { | ||||||
| //!     db.create_scalar_function( | //!     db.create_scalar_function( | ||||||
| @@ -20,34 +22,19 @@ | |||||||
| //!         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, | //!         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, | ||||||
| //!         move |ctx| { | //!         move |ctx| { | ||||||
| //!             assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); | //!             assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); | ||||||
| //! | //!             let regexp: Arc<Regex> = ctx | ||||||
| //!             let saved_re: Option<&Regex> = ctx.get_aux(0)?; | //!                 .get_or_create_aux(0, |vr| -> Result<_, BoxError> { | ||||||
| //!             let new_re = match saved_re { | //!                     Ok(Regex::new(vr.as_str()?)?) | ||||||
| //!                 None => { | //!                 })?; | ||||||
| //!                     let s = ctx.get::<String>(0)?; |  | ||||||
| //!                     match Regex::new(&s) { |  | ||||||
| //!                         Ok(r) => Some(r), |  | ||||||
| //!                         Err(err) => return Err(Error::UserFunctionError(Box::new(err))), |  | ||||||
| //!                     } |  | ||||||
| //!                 } |  | ||||||
| //!                 Some(_) => None, |  | ||||||
| //!             }; |  | ||||||
| //! |  | ||||||
| //!             let is_match = { | //!             let is_match = { | ||||||
| //!                 let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); |  | ||||||
| //! |  | ||||||
| //!                 let text = ctx | //!                 let text = ctx | ||||||
| //!                     .get_raw(1) | //!                     .get_raw(1) | ||||||
| //!                     .as_str() | //!                     .as_str() | ||||||
| //!                     .map_err(|e| Error::UserFunctionError(e.into()))?; | //!                     .map_err(|e| Error::UserFunctionError(e.into()))?; | ||||||
| //! | //! | ||||||
| //!                 re.is_match(text) | //!                 regexp.is_match(text) | ||||||
| //!             }; | //!             }; | ||||||
| //! | //! | ||||||
| //!             if let Some(re) = new_re { |  | ||||||
| //!                 ctx.set_aux(0, re); |  | ||||||
| //!             } |  | ||||||
| //! |  | ||||||
| //!             Ok(is_match) | //!             Ok(is_match) | ||||||
| //!         }, | //!         }, | ||||||
| //!     ) | //!     ) | ||||||
| @@ -67,11 +54,12 @@ | |||||||
| //!     Ok(()) | //!     Ok(()) | ||||||
| //! } | //! } | ||||||
| //! ``` | //! ``` | ||||||
| use std::any::TypeId; | use std::any::Any; | ||||||
| use std::os::raw::{c_int, c_void}; | use std::os::raw::{c_int, c_void}; | ||||||
| use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe}; | use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe}; | ||||||
| use std::ptr; | use std::ptr; | ||||||
| use std::slice; | use std::slice; | ||||||
|  | use std::sync::Arc; | ||||||
|  |  | ||||||
| use crate::ffi; | use crate::ffi; | ||||||
| use crate::ffi::sqlite3_context; | use crate::ffi::sqlite3_context; | ||||||
| @@ -121,6 +109,7 @@ unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) { | |||||||
| pub struct Context<'a> { | pub struct Context<'a> { | ||||||
|     ctx: *mut sqlite3_context, |     ctx: *mut sqlite3_context, | ||||||
|     args: &'a [*mut sqlite3_value], |     args: &'a [*mut sqlite3_value], | ||||||
|  |     // conn: PhantomData<&'conn mut Connection>, | ||||||
| } | } | ||||||
|  |  | ||||||
| impl Context<'_> { | impl Context<'_> { | ||||||
| @@ -174,47 +163,60 @@ impl Context<'_> { | |||||||
|         unsafe { ValueRef::from_value(arg) } |         unsafe { ValueRef::from_value(arg) } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>> | ||||||
|  |     where | ||||||
|  |         T: Send + Sync + 'static, | ||||||
|  |         E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>, | ||||||
|  |         F: FnOnce(ValueRef<'_>) -> Result<T, E>, | ||||||
|  |     { | ||||||
|  |         if let Some(v) = self.get_aux(arg)? { | ||||||
|  |             Ok(v) | ||||||
|  |         } else { | ||||||
|  |             let vr = self.get_raw(arg as usize); | ||||||
|  |             self.set_aux( | ||||||
|  |                 arg, | ||||||
|  |                 func(vr).map_err(|e| Error::UserFunctionError(e.into()))?, | ||||||
|  |             ) | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|     /// Sets the auxilliary data associated with a particular parameter. See |     /// Sets the auxilliary data associated with a particular parameter. See | ||||||
|     /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of |     /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of | ||||||
|     /// this feature, or the unit tests of this module for an example. |     /// this feature, or the unit tests of this module for an example. | ||||||
|     pub fn set_aux<T: 'static>(&self, arg: c_int, value: T) { |     pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> { | ||||||
|         let boxed = Box::into_raw(Box::new(AuxData { |         let orig: Arc<T> = Arc::new(value); | ||||||
|             id: TypeId::of::<T>(), |         let inner: AuxInner = orig.clone(); | ||||||
|             value, |         let outer = Box::new(inner); | ||||||
|         })); |         let raw: *mut AuxInner = Box::into_raw(outer); | ||||||
|         unsafe { |         unsafe { | ||||||
|             ffi::sqlite3_set_auxdata( |             ffi::sqlite3_set_auxdata( | ||||||
|                 self.ctx, |                 self.ctx, | ||||||
|                 arg, |                 arg, | ||||||
|                 boxed as *mut c_void, |                 raw as *mut _, | ||||||
|                 Some(free_boxed_value::<AuxData<T>>), |                 Some(free_boxed_value::<AuxInner>), | ||||||
|             ) |             ) | ||||||
|         }; |         }; | ||||||
|  |         Ok(orig) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     /// Gets the auxilliary data that was associated with a given parameter |     /// Gets the auxilliary data that was associated with a given parameter via | ||||||
|     /// via `set_aux`. Returns `Ok(None)` if no data has been associated, |     /// `set_aux`. Returns `Ok(None)` if no data has been associated, and | ||||||
|     /// and . |     /// Ok(Some(v)) if it has. Returns an error if the requested type does not | ||||||
|     pub fn get_aux<T: 'static>(&self, arg: c_int) -> Result<Option<&T>> { |     /// match. | ||||||
|         let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxData<T> }; |     pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> { | ||||||
|  |         let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner }; | ||||||
|         if p.is_null() { |         if p.is_null() { | ||||||
|             Ok(None) |             Ok(None) | ||||||
|         } else { |         } else { | ||||||
|             let id = unsafe { (*p).id }; |             let v: AuxInner = AuxInner::clone(unsafe { &*p }); | ||||||
|             if TypeId::of::<T>() != id { |             v.downcast::<T>() | ||||||
|                 Err(Error::GetAuxWrongType) |                 .map(Some) | ||||||
|             } else { |                 .map_err(|_| Error::GetAuxWrongType) | ||||||
|                 Ok(Some(unsafe { &(*p).value })) |  | ||||||
|             } |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| #[repr(C)] | type AuxInner = Arc<dyn Any + Send + Sync + 'static>; | ||||||
| struct AuxData<T: 'static> { |  | ||||||
|     id: TypeId, |  | ||||||
|     value: T, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| /// `feature = "functions"` Aggregate is the callback interface for user-defined | /// `feature = "functions"` Aggregate is the callback interface for user-defined | ||||||
| /// aggregate function. | /// aggregate function. | ||||||
| @@ -776,34 +778,21 @@ mod test { | |||||||
|     // expression multiple times within one query. |     // expression multiple times within one query. | ||||||
|     fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> { |     fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> { | ||||||
|         assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); |         assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); | ||||||
|  |         type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; | ||||||
|         let saved_re: Option<&Regex> = ctx.get_aux(0)?; |         let regexp: std::sync::Arc<Regex> = ctx | ||||||
|         let new_re = match saved_re { |             .get_or_create_aux(0, |vr| -> Result<_, BoxError> { | ||||||
|             None => { |                 Ok(Regex::new(vr.as_str()?)?) | ||||||
|                 let s = ctx.get::<String>(0)?; |             })?; | ||||||
|                 match Regex::new(&s) { |  | ||||||
|                     Ok(r) => Some(r), |  | ||||||
|                     Err(err) => return Err(Error::UserFunctionError(Box::new(err))), |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|             Some(_) => None, |  | ||||||
|         }; |  | ||||||
|  |  | ||||||
|         let is_match = { |         let is_match = { | ||||||
|             let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); |  | ||||||
|  |  | ||||||
|             let text = ctx |             let text = ctx | ||||||
|                 .get_raw(1) |                 .get_raw(1) | ||||||
|                 .as_str() |                 .as_str() | ||||||
|                 .map_err(|e| Error::UserFunctionError(e.into()))?; |                 .map_err(|e| Error::UserFunctionError(e.into()))?; | ||||||
|  |  | ||||||
|             re.is_match(text) |             regexp.is_match(text) | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         if let Some(re) = new_re { |  | ||||||
|             ctx.set_aux(0, re); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         Ok(is_match) |         Ok(is_match) | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -878,10 +867,10 @@ mod test { | |||||||
|         let db = Connection::open_in_memory().unwrap(); |         let db = Connection::open_in_memory().unwrap(); | ||||||
|         db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| { |         db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| { | ||||||
|             if !ctx.get::<bool>(1)? { |             if !ctx.get::<bool>(1)? { | ||||||
|                 ctx.set_aux::<i64>(0, 100); |                 ctx.set_aux::<i64>(0, 100)?; | ||||||
|             } else { |             } else { | ||||||
|                 assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType)); |                 assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType)); | ||||||
|                 assert_eq!(ctx.get_aux::<i64>(0), Ok(Some(&100))); |                 assert_eq!(*ctx.get_aux::<i64>(0).unwrap().unwrap(), 100); | ||||||
|             } |             } | ||||||
|             Ok(true) |             Ok(true) | ||||||
|         }) |         }) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user