mirror of
				https://github.com/isar/rusqlite.git
				synced 2025-10-31 05:48:56 +08:00 
			
		
		
		
	Introduce SqlFnOutput trait
To keep compatibility with existing code
This commit is contained in:
		
							
								
								
									
										116
									
								
								src/functions.rs
									
									
									
									
									
								
							
							
						
						
									
										116
									
								
								src/functions.rs
									
									
									
									
									
								
							| @@ -34,7 +34,7 @@ | ||||
| //!                 regexp.is_match(text) | ||||
| //!             }; | ||||
| //! | ||||
| //!             Ok((is_match, None)) | ||||
| //!             Ok(is_match) | ||||
| //!         }, | ||||
| //!     ) | ||||
| //! } | ||||
| @@ -66,7 +66,7 @@ use crate::ffi::sqlite3_context; | ||||
| use crate::ffi::sqlite3_value; | ||||
|  | ||||
| use crate::context::set_result; | ||||
| use crate::types::{FromSql, FromSqlError, ToSql, ValueRef}; | ||||
| use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef}; | ||||
|  | ||||
| use crate::{str_to_cstring, Connection, Error, InnerConnection, Result}; | ||||
|  | ||||
| @@ -256,6 +256,33 @@ type AuxInner = Arc<dyn Any + Send + Sync + 'static>; | ||||
| /// Subtype of an SQL function | ||||
| pub type SubType = Option<std::os::raw::c_uint>; | ||||
|  | ||||
| /// Result of an SQL function | ||||
| pub trait SqlFnOutput { | ||||
|     /// Converts Rust value to SQLite value with an optional sub-type | ||||
|     fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>; | ||||
| } | ||||
|  | ||||
| impl<T: ToSql> SqlFnOutput for T { | ||||
|     fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> { | ||||
|         ToSql::to_sql(self).map(|o| (o, None)) | ||||
|     } | ||||
| } | ||||
|  | ||||
| unsafe fn sql_result<T: SqlFnOutput>(ctx: *mut sqlite3_context, r: Result<T>) { | ||||
|     let t = r.as_ref().map(SqlFnOutput::to_sql); | ||||
|  | ||||
|     match t { | ||||
|         Ok(Ok((ref value, sub_type))) => { | ||||
|             set_result(ctx, value); | ||||
|             if let Some(sub_type) = sub_type { | ||||
|                 ffi::sqlite3_result_subtype(ctx, sub_type); | ||||
|             } | ||||
|         } | ||||
|         Ok(Err(err)) => report_error(ctx, &err), | ||||
|         Err(err) => report_error(ctx, err), | ||||
|     }; | ||||
| } | ||||
|  | ||||
| /// Aggregate is the callback interface for user-defined | ||||
| /// aggregate function. | ||||
| /// | ||||
| @@ -264,7 +291,7 @@ pub type SubType = Option<std::os::raw::c_uint>; | ||||
| pub trait Aggregate<A, T> | ||||
| where | ||||
|     A: RefUnwindSafe + UnwindSafe, | ||||
|     T: ToSql, | ||||
|     T: SqlFnOutput, | ||||
| { | ||||
|     /// Initializes the aggregation context. Will be called prior to the first | ||||
|     /// call to [`step()`](Aggregate::step) to set up the context for an | ||||
| @@ -285,7 +312,7 @@ where | ||||
|     /// given `None`. | ||||
|     /// | ||||
|     /// The passed context will have no arguments. | ||||
|     fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<(T, SubType)>; | ||||
|     fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<T>; | ||||
| } | ||||
|  | ||||
| /// `WindowAggregate` is the callback interface for | ||||
| @@ -295,11 +322,11 @@ where | ||||
| pub trait WindowAggregate<A, T>: Aggregate<A, T> | ||||
| where | ||||
|     A: RefUnwindSafe + UnwindSafe, | ||||
|     T: ToSql, | ||||
|     T: SqlFnOutput, | ||||
| { | ||||
|     /// Returns the current value of the aggregate. Unlike xFinal, the | ||||
|     /// implementation should not delete any context. | ||||
|     fn value(&self, acc: Option<&mut A>) -> Result<(T, SubType)>; | ||||
|     fn value(&self, acc: Option<&mut A>) -> Result<T>; | ||||
|  | ||||
|     /// Removes a row from the current window. | ||||
|     fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>; | ||||
| @@ -365,7 +392,7 @@ impl Connection { | ||||
|     ///         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, | ||||
|     ///         |ctx| { | ||||
|     ///             let value = ctx.get::<f64>(0)?; | ||||
|     ///             Ok((value / 2f64, None)) | ||||
|     ///             Ok(value / 2f64) | ||||
|     ///         }, | ||||
|     ///     )?; | ||||
|     /// | ||||
| @@ -387,8 +414,8 @@ impl Connection { | ||||
|         x_func: F, | ||||
|     ) -> Result<()> | ||||
|     where | ||||
|         F: FnMut(&Context<'_>) -> Result<(T, SubType)> + Send + UnwindSafe + 'static, | ||||
|         T: ToSql, | ||||
|         F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, | ||||
|         T: SqlFnOutput, | ||||
|     { | ||||
|         self.db | ||||
|             .borrow_mut() | ||||
| @@ -412,7 +439,7 @@ impl Connection { | ||||
|     where | ||||
|         A: RefUnwindSafe + UnwindSafe, | ||||
|         D: Aggregate<A, T> + 'static, | ||||
|         T: ToSql, | ||||
|         T: SqlFnOutput, | ||||
|     { | ||||
|         self.db | ||||
|             .borrow_mut() | ||||
| @@ -469,16 +496,16 @@ impl InnerConnection { | ||||
|         x_func: F, | ||||
|     ) -> Result<()> | ||||
|     where | ||||
|         F: FnMut(&Context<'_>) -> Result<(T, SubType)> + Send + UnwindSafe + 'static, | ||||
|         T: ToSql, | ||||
|         F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, | ||||
|         T: SqlFnOutput, | ||||
|     { | ||||
|         unsafe extern "C" fn call_boxed_closure<F, T>( | ||||
|             ctx: *mut sqlite3_context, | ||||
|             argc: c_int, | ||||
|             argv: *mut *mut sqlite3_value, | ||||
|         ) where | ||||
|             F: FnMut(&Context<'_>) -> Result<(T, SubType)>, | ||||
|             T: ToSql, | ||||
|             F: FnMut(&Context<'_>) -> Result<T>, | ||||
|             T: SqlFnOutput, | ||||
|         { | ||||
|             let r = catch_unwind(|| { | ||||
|                 let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>(); | ||||
| @@ -496,18 +523,7 @@ impl InnerConnection { | ||||
|                 } | ||||
|                 Ok(r) => r, | ||||
|             }; | ||||
|             let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); | ||||
|  | ||||
|             match t { | ||||
|                 Ok((Ok(ref value), sub_type)) => { | ||||
|                     set_result(ctx, value); | ||||
|                     if let Some(sub_type) = sub_type { | ||||
|                         ffi::sqlite3_result_subtype(ctx, *sub_type); | ||||
|                     } | ||||
|                 } | ||||
|                 Ok((Err(err), _)) => report_error(ctx, &err), | ||||
|                 Err(err) => report_error(ctx, err), | ||||
|             } | ||||
|             sql_result(ctx, t); | ||||
|         } | ||||
|  | ||||
|         let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); | ||||
| @@ -538,7 +554,7 @@ impl InnerConnection { | ||||
|     where | ||||
|         A: RefUnwindSafe + UnwindSafe, | ||||
|         D: Aggregate<A, T> + 'static, | ||||
|         T: ToSql, | ||||
|         T: SqlFnOutput, | ||||
|     { | ||||
|         let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); | ||||
|         let c_name = str_to_cstring(fn_name)?; | ||||
| @@ -624,7 +640,7 @@ unsafe extern "C" fn call_boxed_step<A, D, T>( | ||||
| ) where | ||||
|     A: RefUnwindSafe + UnwindSafe, | ||||
|     D: Aggregate<A, T>, | ||||
|     T: ToSql, | ||||
|     T: SqlFnOutput, | ||||
| { | ||||
|     let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) { | ||||
|         pac | ||||
| @@ -710,7 +726,7 @@ unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context) | ||||
| where | ||||
|     A: RefUnwindSafe + UnwindSafe, | ||||
|     D: Aggregate<A, T>, | ||||
|     T: ToSql, | ||||
|     T: SqlFnOutput, | ||||
| { | ||||
|     // Within the xFinal callback, it is customary to set N=0 in calls to | ||||
|     // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. | ||||
| @@ -744,17 +760,7 @@ where | ||||
|         } | ||||
|         Ok(r) => r, | ||||
|     }; | ||||
|     let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); | ||||
|     match t { | ||||
|         Ok((Ok(ref value), sub_type)) => { | ||||
|             set_result(ctx, value); | ||||
|             if let Some(sub_type) = sub_type { | ||||
|                 ffi::sqlite3_result_subtype(ctx, *sub_type); | ||||
|             } | ||||
|         } | ||||
|         Ok((Err(err), _)) => report_error(ctx, &err), | ||||
|         Err(err) => report_error(ctx, err), | ||||
|     } | ||||
|     sql_result(ctx, t); | ||||
| } | ||||
|  | ||||
| #[cfg(feature = "window")] | ||||
| @@ -806,13 +812,13 @@ mod test { | ||||
|  | ||||
|     #[cfg(feature = "window")] | ||||
|     use crate::functions::WindowAggregate; | ||||
|     use crate::functions::{Aggregate, Context, FunctionFlags, SubType}; | ||||
|     use crate::functions::{Aggregate, Context, FunctionFlags}; | ||||
|     use crate::{Connection, Error, Result}; | ||||
|  | ||||
|     fn half(ctx: &Context<'_>) -> Result<(c_double, SubType)> { | ||||
|     fn half(ctx: &Context<'_>) -> Result<c_double> { | ||||
|         assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); | ||||
|         let value = ctx.get::<c_double>(0)?; | ||||
|         Ok((value / 2f64, None)) | ||||
|         Ok(value / 2f64) | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
| @@ -851,7 +857,7 @@ mod test { | ||||
|     // This implementation of a regexp scalar function uses SQLite's auxiliary data | ||||
|     // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular | ||||
|     // expression multiple times within one query. | ||||
|     fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<(bool, SubType)> { | ||||
|     fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<bool> { | ||||
|         assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); | ||||
|         type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; | ||||
|         let regexp: std::sync::Arc<Regex> = ctx | ||||
| @@ -868,7 +874,7 @@ mod test { | ||||
|             regexp.is_match(text) | ||||
|         }; | ||||
|  | ||||
|         Ok((is_match, None)) | ||||
|         Ok(is_match) | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
| @@ -915,7 +921,7 @@ mod test { | ||||
|                     ret.push_str(&s); | ||||
|                 } | ||||
|  | ||||
|                 Ok((ret, None)) | ||||
|                 Ok(ret) | ||||
|             }, | ||||
|         )?; | ||||
|  | ||||
| @@ -940,7 +946,7 @@ mod test { | ||||
|                 assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType)); | ||||
|                 assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100); | ||||
|             } | ||||
|             Ok((true, None)) | ||||
|             Ok(true) | ||||
|         })?; | ||||
|  | ||||
|         let res: bool = | ||||
| @@ -963,12 +969,8 @@ mod test { | ||||
|             Ok(()) | ||||
|         } | ||||
|  | ||||
|         fn finalize( | ||||
|             &self, | ||||
|             _: &mut Context<'_>, | ||||
|             sum: Option<i64>, | ||||
|         ) -> Result<(Option<i64>, SubType)> { | ||||
|             Ok((sum, None)) | ||||
|         fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> { | ||||
|             Ok(sum) | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -982,8 +984,8 @@ mod test { | ||||
|             Ok(()) | ||||
|         } | ||||
|  | ||||
|         fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<(i64, SubType)> { | ||||
|             Ok((sum.unwrap_or(0), None)) | ||||
|         fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> { | ||||
|             Ok(sum.unwrap_or(0)) | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -1041,8 +1043,8 @@ mod test { | ||||
|             Ok(()) | ||||
|         } | ||||
|  | ||||
|         fn value(&self, sum: Option<&mut i64>) -> Result<(Option<i64>, SubType)> { | ||||
|             Ok((sum.copied(), None)) | ||||
|         fn value(&self, sum: Option<&mut i64>) -> Result<Option<i64>> { | ||||
|             Ok(sum.copied()) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|   | ||||
| @@ -1801,7 +1801,7 @@ mod test { | ||||
|             functions::FunctionFlags::default(), | ||||
|             move |_| { | ||||
|                 interrupt_handle.interrupt(); | ||||
|                 Ok((0, None)) | ||||
|                 Ok(0) | ||||
|             }, | ||||
|         )?; | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user