mirror of
				https://github.com/isar/rusqlite.git
				synced 2025-10-31 13:58:55 +08:00 
			
		
		
		
	Introduce SqlFnOutput trait
To keep compatibility with existing code
This commit is contained in:
		
							
								
								
									
										116
									
								
								src/functions.rs
									
									
									
									
									
								
							
							
						
						
									
										116
									
								
								src/functions.rs
									
									
									
									
									
								
							| @@ -34,7 +34,7 @@ | |||||||
| //!                 regexp.is_match(text) | //!                 regexp.is_match(text) | ||||||
| //!             }; | //!             }; | ||||||
| //! | //! | ||||||
| //!             Ok((is_match, None)) | //!             Ok(is_match) | ||||||
| //!         }, | //!         }, | ||||||
| //!     ) | //!     ) | ||||||
| //! } | //! } | ||||||
| @@ -66,7 +66,7 @@ use crate::ffi::sqlite3_context; | |||||||
| use crate::ffi::sqlite3_value; | use crate::ffi::sqlite3_value; | ||||||
|  |  | ||||||
| use crate::context::set_result; | use crate::context::set_result; | ||||||
| use crate::types::{FromSql, FromSqlError, ToSql, ValueRef}; | use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef}; | ||||||
|  |  | ||||||
| use crate::{str_to_cstring, Connection, Error, InnerConnection, Result}; | use crate::{str_to_cstring, Connection, Error, InnerConnection, Result}; | ||||||
|  |  | ||||||
| @@ -256,6 +256,33 @@ type AuxInner = Arc<dyn Any + Send + Sync + 'static>; | |||||||
| /// Subtype of an SQL function | /// Subtype of an SQL function | ||||||
| pub type SubType = Option<std::os::raw::c_uint>; | pub type SubType = Option<std::os::raw::c_uint>; | ||||||
|  |  | ||||||
|  | /// Result of an SQL function | ||||||
|  | pub trait SqlFnOutput { | ||||||
|  |     /// Converts Rust value to SQLite value with an optional sub-type | ||||||
|  |     fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | impl<T: ToSql> SqlFnOutput for T { | ||||||
|  |     fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> { | ||||||
|  |         ToSql::to_sql(self).map(|o| (o, None)) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | unsafe fn sql_result<T: SqlFnOutput>(ctx: *mut sqlite3_context, r: Result<T>) { | ||||||
|  |     let t = r.as_ref().map(SqlFnOutput::to_sql); | ||||||
|  |  | ||||||
|  |     match t { | ||||||
|  |         Ok(Ok((ref value, sub_type))) => { | ||||||
|  |             set_result(ctx, value); | ||||||
|  |             if let Some(sub_type) = sub_type { | ||||||
|  |                 ffi::sqlite3_result_subtype(ctx, sub_type); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         Ok(Err(err)) => report_error(ctx, &err), | ||||||
|  |         Err(err) => report_error(ctx, err), | ||||||
|  |     }; | ||||||
|  | } | ||||||
|  |  | ||||||
| /// Aggregate is the callback interface for user-defined | /// Aggregate is the callback interface for user-defined | ||||||
| /// aggregate function. | /// aggregate function. | ||||||
| /// | /// | ||||||
| @@ -264,7 +291,7 @@ pub type SubType = Option<std::os::raw::c_uint>; | |||||||
| pub trait Aggregate<A, T> | pub trait Aggregate<A, T> | ||||||
| where | where | ||||||
|     A: RefUnwindSafe + UnwindSafe, |     A: RefUnwindSafe + UnwindSafe, | ||||||
|     T: ToSql, |     T: SqlFnOutput, | ||||||
| { | { | ||||||
|     /// Initializes the aggregation context. Will be called prior to the first |     /// Initializes the aggregation context. Will be called prior to the first | ||||||
|     /// call to [`step()`](Aggregate::step) to set up the context for an |     /// call to [`step()`](Aggregate::step) to set up the context for an | ||||||
| @@ -285,7 +312,7 @@ where | |||||||
|     /// given `None`. |     /// given `None`. | ||||||
|     /// |     /// | ||||||
|     /// The passed context will have no arguments. |     /// The passed context will have no arguments. | ||||||
|     fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<(T, SubType)>; |     fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<T>; | ||||||
| } | } | ||||||
|  |  | ||||||
| /// `WindowAggregate` is the callback interface for | /// `WindowAggregate` is the callback interface for | ||||||
| @@ -295,11 +322,11 @@ where | |||||||
| pub trait WindowAggregate<A, T>: Aggregate<A, T> | pub trait WindowAggregate<A, T>: Aggregate<A, T> | ||||||
| where | where | ||||||
|     A: RefUnwindSafe + UnwindSafe, |     A: RefUnwindSafe + UnwindSafe, | ||||||
|     T: ToSql, |     T: SqlFnOutput, | ||||||
| { | { | ||||||
|     /// Returns the current value of the aggregate. Unlike xFinal, the |     /// Returns the current value of the aggregate. Unlike xFinal, the | ||||||
|     /// implementation should not delete any context. |     /// implementation should not delete any context. | ||||||
|     fn value(&self, acc: Option<&mut A>) -> Result<(T, SubType)>; |     fn value(&self, acc: Option<&mut A>) -> Result<T>; | ||||||
|  |  | ||||||
|     /// Removes a row from the current window. |     /// Removes a row from the current window. | ||||||
|     fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>; |     fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>; | ||||||
| @@ -365,7 +392,7 @@ impl Connection { | |||||||
|     ///         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, |     ///         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, | ||||||
|     ///         |ctx| { |     ///         |ctx| { | ||||||
|     ///             let value = ctx.get::<f64>(0)?; |     ///             let value = ctx.get::<f64>(0)?; | ||||||
|     ///             Ok((value / 2f64, None)) |     ///             Ok(value / 2f64) | ||||||
|     ///         }, |     ///         }, | ||||||
|     ///     )?; |     ///     )?; | ||||||
|     /// |     /// | ||||||
| @@ -387,8 +414,8 @@ impl Connection { | |||||||
|         x_func: F, |         x_func: F, | ||||||
|     ) -> Result<()> |     ) -> Result<()> | ||||||
|     where |     where | ||||||
|         F: FnMut(&Context<'_>) -> Result<(T, SubType)> + Send + UnwindSafe + 'static, |         F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, | ||||||
|         T: ToSql, |         T: SqlFnOutput, | ||||||
|     { |     { | ||||||
|         self.db |         self.db | ||||||
|             .borrow_mut() |             .borrow_mut() | ||||||
| @@ -412,7 +439,7 @@ impl Connection { | |||||||
|     where |     where | ||||||
|         A: RefUnwindSafe + UnwindSafe, |         A: RefUnwindSafe + UnwindSafe, | ||||||
|         D: Aggregate<A, T> + 'static, |         D: Aggregate<A, T> + 'static, | ||||||
|         T: ToSql, |         T: SqlFnOutput, | ||||||
|     { |     { | ||||||
|         self.db |         self.db | ||||||
|             .borrow_mut() |             .borrow_mut() | ||||||
| @@ -469,16 +496,16 @@ impl InnerConnection { | |||||||
|         x_func: F, |         x_func: F, | ||||||
|     ) -> Result<()> |     ) -> Result<()> | ||||||
|     where |     where | ||||||
|         F: FnMut(&Context<'_>) -> Result<(T, SubType)> + Send + UnwindSafe + 'static, |         F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static, | ||||||
|         T: ToSql, |         T: SqlFnOutput, | ||||||
|     { |     { | ||||||
|         unsafe extern "C" fn call_boxed_closure<F, T>( |         unsafe extern "C" fn call_boxed_closure<F, T>( | ||||||
|             ctx: *mut sqlite3_context, |             ctx: *mut sqlite3_context, | ||||||
|             argc: c_int, |             argc: c_int, | ||||||
|             argv: *mut *mut sqlite3_value, |             argv: *mut *mut sqlite3_value, | ||||||
|         ) where |         ) where | ||||||
|             F: FnMut(&Context<'_>) -> Result<(T, SubType)>, |             F: FnMut(&Context<'_>) -> Result<T>, | ||||||
|             T: ToSql, |             T: SqlFnOutput, | ||||||
|         { |         { | ||||||
|             let r = catch_unwind(|| { |             let r = catch_unwind(|| { | ||||||
|                 let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>(); |                 let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>(); | ||||||
| @@ -496,18 +523,7 @@ impl InnerConnection { | |||||||
|                 } |                 } | ||||||
|                 Ok(r) => r, |                 Ok(r) => r, | ||||||
|             }; |             }; | ||||||
|             let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); |             sql_result(ctx, t); | ||||||
|  |  | ||||||
|             match t { |  | ||||||
|                 Ok((Ok(ref value), sub_type)) => { |  | ||||||
|                     set_result(ctx, value); |  | ||||||
|                     if let Some(sub_type) = sub_type { |  | ||||||
|                         ffi::sqlite3_result_subtype(ctx, *sub_type); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|                 Ok((Err(err), _)) => report_error(ctx, &err), |  | ||||||
|                 Err(err) => report_error(ctx, err), |  | ||||||
|             } |  | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); |         let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); | ||||||
| @@ -538,7 +554,7 @@ impl InnerConnection { | |||||||
|     where |     where | ||||||
|         A: RefUnwindSafe + UnwindSafe, |         A: RefUnwindSafe + UnwindSafe, | ||||||
|         D: Aggregate<A, T> + 'static, |         D: Aggregate<A, T> + 'static, | ||||||
|         T: ToSql, |         T: SqlFnOutput, | ||||||
|     { |     { | ||||||
|         let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); |         let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); | ||||||
|         let c_name = str_to_cstring(fn_name)?; |         let c_name = str_to_cstring(fn_name)?; | ||||||
| @@ -624,7 +640,7 @@ unsafe extern "C" fn call_boxed_step<A, D, T>( | |||||||
| ) where | ) where | ||||||
|     A: RefUnwindSafe + UnwindSafe, |     A: RefUnwindSafe + UnwindSafe, | ||||||
|     D: Aggregate<A, T>, |     D: Aggregate<A, T>, | ||||||
|     T: ToSql, |     T: SqlFnOutput, | ||||||
| { | { | ||||||
|     let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) { |     let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) { | ||||||
|         pac |         pac | ||||||
| @@ -710,7 +726,7 @@ unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context) | |||||||
| where | where | ||||||
|     A: RefUnwindSafe + UnwindSafe, |     A: RefUnwindSafe + UnwindSafe, | ||||||
|     D: Aggregate<A, T>, |     D: Aggregate<A, T>, | ||||||
|     T: ToSql, |     T: SqlFnOutput, | ||||||
| { | { | ||||||
|     // Within the xFinal callback, it is customary to set N=0 in calls to |     // Within the xFinal callback, it is customary to set N=0 in calls to | ||||||
|     // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. |     // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. | ||||||
| @@ -744,17 +760,7 @@ where | |||||||
|         } |         } | ||||||
|         Ok(r) => r, |         Ok(r) => r, | ||||||
|     }; |     }; | ||||||
|     let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); |     sql_result(ctx, t); | ||||||
|     match t { |  | ||||||
|         Ok((Ok(ref value), sub_type)) => { |  | ||||||
|             set_result(ctx, value); |  | ||||||
|             if let Some(sub_type) = sub_type { |  | ||||||
|                 ffi::sqlite3_result_subtype(ctx, *sub_type); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         Ok((Err(err), _)) => report_error(ctx, &err), |  | ||||||
|         Err(err) => report_error(ctx, err), |  | ||||||
|     } |  | ||||||
| } | } | ||||||
|  |  | ||||||
| #[cfg(feature = "window")] | #[cfg(feature = "window")] | ||||||
| @@ -806,13 +812,13 @@ mod test { | |||||||
|  |  | ||||||
|     #[cfg(feature = "window")] |     #[cfg(feature = "window")] | ||||||
|     use crate::functions::WindowAggregate; |     use crate::functions::WindowAggregate; | ||||||
|     use crate::functions::{Aggregate, Context, FunctionFlags, SubType}; |     use crate::functions::{Aggregate, Context, FunctionFlags}; | ||||||
|     use crate::{Connection, Error, Result}; |     use crate::{Connection, Error, Result}; | ||||||
|  |  | ||||||
|     fn half(ctx: &Context<'_>) -> Result<(c_double, SubType)> { |     fn half(ctx: &Context<'_>) -> Result<c_double> { | ||||||
|         assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); |         assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); | ||||||
|         let value = ctx.get::<c_double>(0)?; |         let value = ctx.get::<c_double>(0)?; | ||||||
|         Ok((value / 2f64, None)) |         Ok(value / 2f64) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     #[test] |     #[test] | ||||||
| @@ -851,7 +857,7 @@ mod test { | |||||||
|     // This implementation of a regexp scalar function uses SQLite's auxiliary data |     // This implementation of a regexp scalar function uses SQLite's auxiliary data | ||||||
|     // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular |     // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular | ||||||
|     // expression multiple times within one query. |     // expression multiple times within one query. | ||||||
|     fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<(bool, SubType)> { |     fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<bool> { | ||||||
|         assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); |         assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); | ||||||
|         type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; |         type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; | ||||||
|         let regexp: std::sync::Arc<Regex> = ctx |         let regexp: std::sync::Arc<Regex> = ctx | ||||||
| @@ -868,7 +874,7 @@ mod test { | |||||||
|             regexp.is_match(text) |             regexp.is_match(text) | ||||||
|         }; |         }; | ||||||
|  |  | ||||||
|         Ok((is_match, None)) |         Ok(is_match) | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     #[test] |     #[test] | ||||||
| @@ -915,7 +921,7 @@ mod test { | |||||||
|                     ret.push_str(&s); |                     ret.push_str(&s); | ||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 Ok((ret, None)) |                 Ok(ret) | ||||||
|             }, |             }, | ||||||
|         )?; |         )?; | ||||||
|  |  | ||||||
| @@ -940,7 +946,7 @@ mod test { | |||||||
|                 assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType)); |                 assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType)); | ||||||
|                 assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100); |                 assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100); | ||||||
|             } |             } | ||||||
|             Ok((true, None)) |             Ok(true) | ||||||
|         })?; |         })?; | ||||||
|  |  | ||||||
|         let res: bool = |         let res: bool = | ||||||
| @@ -963,12 +969,8 @@ mod test { | |||||||
|             Ok(()) |             Ok(()) | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         fn finalize( |         fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> { | ||||||
|             &self, |             Ok(sum) | ||||||
|             _: &mut Context<'_>, |  | ||||||
|             sum: Option<i64>, |  | ||||||
|         ) -> Result<(Option<i64>, SubType)> { |  | ||||||
|             Ok((sum, None)) |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -982,8 +984,8 @@ mod test { | |||||||
|             Ok(()) |             Ok(()) | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<(i64, SubType)> { |         fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> { | ||||||
|             Ok((sum.unwrap_or(0), None)) |             Ok(sum.unwrap_or(0)) | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -1041,8 +1043,8 @@ mod test { | |||||||
|             Ok(()) |             Ok(()) | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         fn value(&self, sum: Option<&mut i64>) -> Result<(Option<i64>, SubType)> { |         fn value(&self, sum: Option<&mut i64>) -> Result<Option<i64>> { | ||||||
|             Ok((sum.copied(), None)) |             Ok(sum.copied()) | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1801,7 +1801,7 @@ mod test { | |||||||
|             functions::FunctionFlags::default(), |             functions::FunctionFlags::default(), | ||||||
|             move |_| { |             move |_| { | ||||||
|                 interrupt_handle.interrupt(); |                 interrupt_handle.interrupt(); | ||||||
|                 Ok((0, None)) |                 Ok(0) | ||||||
|             }, |             }, | ||||||
|         )?; |         )?; | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user