From 7ed8e0ef2f47cbd1aae5d1974a2f6ebca9ced556 Mon Sep 17 00:00:00 2001 From: gwenn Date: Sat, 20 Jan 2024 17:08:15 +0100 Subject: [PATCH] Introduce SqlFnOutput trait To keep compatibility with existing code --- src/functions.rs | 116 ++++++++++++++++++++++++----------------------- src/lib.rs | 2 +- 2 files changed, 60 insertions(+), 58 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index f599e63..ecc423c 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -34,7 +34,7 @@ //! regexp.is_match(text) //! }; //! -//! Ok((is_match, None)) +//! Ok(is_match) //! }, //! ) //! } @@ -66,7 +66,7 @@ use crate::ffi::sqlite3_context; use crate::ffi::sqlite3_value; use crate::context::set_result; -use crate::types::{FromSql, FromSqlError, ToSql, ValueRef}; +use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef}; use crate::{str_to_cstring, Connection, Error, InnerConnection, Result}; @@ -256,6 +256,33 @@ type AuxInner = Arc; /// Subtype of an SQL function pub type SubType = Option; +/// Result of an SQL function +pub trait SqlFnOutput { + /// Converts Rust value to SQLite value with an optional sub-type + fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>; +} + +impl SqlFnOutput for T { + fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> { + ToSql::to_sql(self).map(|o| (o, None)) + } +} + +unsafe fn sql_result(ctx: *mut sqlite3_context, r: Result) { + let t = r.as_ref().map(SqlFnOutput::to_sql); + + match t { + Ok(Ok((ref value, sub_type))) => { + set_result(ctx, value); + if let Some(sub_type) = sub_type { + ffi::sqlite3_result_subtype(ctx, sub_type); + } + } + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), + }; +} + /// Aggregate is the callback interface for user-defined /// aggregate function. /// @@ -264,7 +291,7 @@ pub type SubType = Option; pub trait Aggregate where A: RefUnwindSafe + UnwindSafe, - T: ToSql, + T: SqlFnOutput, { /// Initializes the aggregation context. Will be called prior to the first /// call to [`step()`](Aggregate::step) to set up the context for an @@ -285,7 +312,7 @@ where /// given `None`. /// /// The passed context will have no arguments. - fn finalize(&self, ctx: &mut Context<'_>, acc: Option) -> Result<(T, SubType)>; + fn finalize(&self, ctx: &mut Context<'_>, acc: Option) -> Result; } /// `WindowAggregate` is the callback interface for @@ -295,11 +322,11 @@ where pub trait WindowAggregate: Aggregate where A: RefUnwindSafe + UnwindSafe, - T: ToSql, + T: SqlFnOutput, { /// Returns the current value of the aggregate. Unlike xFinal, the /// implementation should not delete any context. - fn value(&self, acc: Option<&mut A>) -> Result<(T, SubType)>; + fn value(&self, acc: Option<&mut A>) -> Result; /// Removes a row from the current window. fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>; @@ -365,7 +392,7 @@ impl Connection { /// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, /// |ctx| { /// let value = ctx.get::(0)?; - /// Ok((value / 2f64, None)) + /// Ok(value / 2f64) /// }, /// )?; /// @@ -387,8 +414,8 @@ impl Connection { x_func: F, ) -> Result<()> where - F: FnMut(&Context<'_>) -> Result<(T, SubType)> + Send + UnwindSafe + 'static, - T: ToSql, + F: FnMut(&Context<'_>) -> Result + Send + UnwindSafe + 'static, + T: SqlFnOutput, { self.db .borrow_mut() @@ -412,7 +439,7 @@ impl Connection { where A: RefUnwindSafe + UnwindSafe, D: Aggregate + 'static, - T: ToSql, + T: SqlFnOutput, { self.db .borrow_mut() @@ -469,16 +496,16 @@ impl InnerConnection { x_func: F, ) -> Result<()> where - F: FnMut(&Context<'_>) -> Result<(T, SubType)> + Send + UnwindSafe + 'static, - T: ToSql, + F: FnMut(&Context<'_>) -> Result + Send + UnwindSafe + 'static, + T: SqlFnOutput, { unsafe extern "C" fn call_boxed_closure( ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value, ) where - F: FnMut(&Context<'_>) -> Result<(T, SubType)>, - T: ToSql, + F: FnMut(&Context<'_>) -> Result, + T: SqlFnOutput, { let r = catch_unwind(|| { let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::(); @@ -496,18 +523,7 @@ impl InnerConnection { } Ok(r) => r, }; - let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); - - match t { - Ok((Ok(ref value), sub_type)) => { - set_result(ctx, value); - if let Some(sub_type) = sub_type { - ffi::sqlite3_result_subtype(ctx, *sub_type); - } - } - Ok((Err(err), _)) => report_error(ctx, &err), - Err(err) => report_error(ctx, err), - } + sql_result(ctx, t); } let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); @@ -538,7 +554,7 @@ impl InnerConnection { where A: RefUnwindSafe + UnwindSafe, D: Aggregate + 'static, - T: ToSql, + T: SqlFnOutput, { let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); let c_name = str_to_cstring(fn_name)?; @@ -624,7 +640,7 @@ unsafe extern "C" fn call_boxed_step( ) where A: RefUnwindSafe + UnwindSafe, D: Aggregate, - T: ToSql, + T: SqlFnOutput, { let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) { pac @@ -710,7 +726,7 @@ unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) where A: RefUnwindSafe + UnwindSafe, D: Aggregate, - T: ToSql, + T: SqlFnOutput, { // Within the xFinal callback, it is customary to set N=0 in calls to // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. @@ -744,17 +760,7 @@ where } Ok(r) => r, }; - let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); - match t { - Ok((Ok(ref value), sub_type)) => { - set_result(ctx, value); - if let Some(sub_type) = sub_type { - ffi::sqlite3_result_subtype(ctx, *sub_type); - } - } - Ok((Err(err), _)) => report_error(ctx, &err), - Err(err) => report_error(ctx, err), - } + sql_result(ctx, t); } #[cfg(feature = "window")] @@ -806,13 +812,13 @@ mod test { #[cfg(feature = "window")] use crate::functions::WindowAggregate; - use crate::functions::{Aggregate, Context, FunctionFlags, SubType}; + use crate::functions::{Aggregate, Context, FunctionFlags}; use crate::{Connection, Error, Result}; - fn half(ctx: &Context<'_>) -> Result<(c_double, SubType)> { + fn half(ctx: &Context<'_>) -> Result { assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); let value = ctx.get::(0)?; - Ok((value / 2f64, None)) + Ok(value / 2f64) } #[test] @@ -851,7 +857,7 @@ mod test { // This implementation of a regexp scalar function uses SQLite's auxiliary data // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular // expression multiple times within one query. - fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<(bool, SubType)> { + fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result { assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); type BoxError = Box; let regexp: std::sync::Arc = ctx @@ -868,7 +874,7 @@ mod test { regexp.is_match(text) }; - Ok((is_match, None)) + Ok(is_match) } #[test] @@ -915,7 +921,7 @@ mod test { ret.push_str(&s); } - Ok((ret, None)) + Ok(ret) }, )?; @@ -940,7 +946,7 @@ mod test { assert_eq!(ctx.get_aux::(0), Err(Error::GetAuxWrongType)); assert_eq!(*ctx.get_aux::(0)?.unwrap(), 100); } - Ok((true, None)) + Ok(true) })?; let res: bool = @@ -963,12 +969,8 @@ mod test { Ok(()) } - fn finalize( - &self, - _: &mut Context<'_>, - sum: Option, - ) -> Result<(Option, SubType)> { - Ok((sum, None)) + fn finalize(&self, _: &mut Context<'_>, sum: Option) -> Result> { + Ok(sum) } } @@ -982,8 +984,8 @@ mod test { Ok(()) } - fn finalize(&self, _: &mut Context<'_>, sum: Option) -> Result<(i64, SubType)> { - Ok((sum.unwrap_or(0), None)) + fn finalize(&self, _: &mut Context<'_>, sum: Option) -> Result { + Ok(sum.unwrap_or(0)) } } @@ -1041,8 +1043,8 @@ mod test { Ok(()) } - fn value(&self, sum: Option<&mut i64>) -> Result<(Option, SubType)> { - Ok((sum.copied(), None)) + fn value(&self, sum: Option<&mut i64>) -> Result> { + Ok(sum.copied()) } } diff --git a/src/lib.rs b/src/lib.rs index f86ed67..1448707 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1801,7 +1801,7 @@ mod test { functions::FunctionFlags::default(), move |_| { interrupt_handle.interrupt(); - Ok((0, None)) + Ok(0) }, )?;