From ee48859794487579416c25c1a50af684e533d652 Mon Sep 17 00:00:00 2001 From: gwenn Date: Thu, 7 Apr 2022 10:43:16 +0200 Subject: [PATCH 1/6] Make possible to specify subtype of SQL function --- src/functions.rs | 90 +++++++++++++++++++++++++++++------------------- src/lib.rs | 2 +- 2 files changed, 55 insertions(+), 37 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 6de9612..5a85752 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -34,7 +34,7 @@ //! regexp.is_match(text) //! }; //! -//! Ok(is_match) +//! Ok((is_match, None)) //! }, //! ) //! } @@ -247,13 +247,6 @@ impl Context<'_> { phantom: PhantomData, }) } - - /// Set the Subtype of an SQL function - #[cfg(feature = "modern_sqlite")] // 3.9.0 - #[cfg_attr(docsrs, doc(cfg(feature = "modern_sqlite")))] - pub fn set_result_subtype(&self, sub_type: std::os::raw::c_uint) { - unsafe { ffi::sqlite3_result_subtype(self.ctx, sub_type) }; - } } /// A reference to a connection handle with a lifetime bound to something. @@ -275,6 +268,9 @@ impl Deref for ConnectionRef<'_> { type AuxInner = Arc; +/// Subtype of an SQL function +pub type SubType = Option; + /// Aggregate is the callback interface for user-defined /// aggregate function. /// @@ -304,7 +300,7 @@ where /// given `None`. /// /// The passed context will have no arguments. - fn finalize(&self, _: &mut Context<'_>, _: Option) -> Result; + fn finalize(&self, _: &mut Context<'_>, _: Option) -> Result<(T, SubType)>; } /// `WindowAggregate` is the callback interface for @@ -318,7 +314,7 @@ where { /// Returns the current value of the aggregate. Unlike xFinal, the /// implementation should not delete any context. - fn value(&self, _: Option<&A>) -> Result; + fn value(&self, _: Option<&A>) -> Result<(T, SubType)>; /// Removes a row from the current window. fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>; @@ -381,7 +377,7 @@ impl Connection { /// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, /// |ctx| { /// let value = ctx.get::(0)?; - /// Ok(value / 2f64) + /// Ok((value / 2f64, None)) /// }, /// )?; /// @@ -403,7 +399,7 @@ impl Connection { x_func: F, ) -> Result<()> where - F: FnMut(&Context<'_>) -> Result + Send + UnwindSafe + 'static, + F: FnMut(&Context<'_>) -> Result<(T, SubType)> + Send + UnwindSafe + 'static, T: ToSql, { self.db @@ -485,7 +481,7 @@ impl InnerConnection { x_func: F, ) -> Result<()> where - F: FnMut(&Context<'_>) -> Result + Send + UnwindSafe + 'static, + F: FnMut(&Context<'_>) -> Result<(T, SubType)> + Send + UnwindSafe + 'static, T: ToSql, { unsafe extern "C" fn call_boxed_closure( @@ -493,7 +489,7 @@ impl InnerConnection { argc: c_int, argv: *mut *mut sqlite3_value, ) where - F: FnMut(&Context<'_>) -> Result, + F: FnMut(&Context<'_>) -> Result<(T, SubType)>, T: ToSql, { let r = catch_unwind(|| { @@ -512,11 +508,17 @@ impl InnerConnection { } Ok(r) => r, }; - let t = t.as_ref().map(|t| ToSql::to_sql(t)); + let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); match t { - Ok(Ok(ref value)) => set_result(ctx, value), - Ok(Err(err)) => report_error(ctx, &err), + Ok((Ok(ref value), sub_type)) => { + set_result(ctx, value); + #[cfg(feature = "modern_sqlite")] // 3.9.0 + if let Some(sub_type) = sub_type { + ffi::sqlite3_result_subtype(ctx, *sub_type); + } + } + Ok((Err(err), _)) => report_error(ctx, &err), Err(err) => report_error(ctx, err), } } @@ -752,10 +754,16 @@ where } Ok(r) => r, }; - let t = t.as_ref().map(|t| ToSql::to_sql(t)); + let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); match t { - Ok(Ok(ref value)) => set_result(ctx, value), - Ok(Err(err)) => report_error(ctx, &err), + Ok((Ok(ref value), sub_type)) => { + set_result(ctx, value); + #[cfg(feature = "modern_sqlite")] // 3.9.0 + if let Some(sub_type) = sub_type { + ffi::sqlite3_result_subtype(ctx, *sub_type); + } + } + Ok((Err(err), _)) => report_error(ctx, &err), Err(err) => report_error(ctx, err), } } @@ -796,10 +804,16 @@ where } Ok(r) => r, }; - let t = t.as_ref().map(|t| ToSql::to_sql(t)); + let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); match t { - Ok(Ok(ref value)) => set_result(ctx, value), - Ok(Err(err)) => report_error(ctx, &err), + Ok((Ok(ref value), sub_type)) => { + set_result(ctx, value); + #[cfg(feature = "modern_sqlite")] // 3.9.0 + if let Some(sub_type) = sub_type { + ffi::sqlite3_result_subtype(ctx, *sub_type); + } + } + Ok((Err(err), _)) => report_error(ctx, &err), Err(err) => report_error(ctx, err), } } @@ -812,13 +826,13 @@ mod test { #[cfg(feature = "window")] use crate::functions::WindowAggregate; - use crate::functions::{Aggregate, Context, FunctionFlags}; + use crate::functions::{Aggregate, Context, FunctionFlags, SubType}; use crate::{Connection, Error, Result}; - fn half(ctx: &Context<'_>) -> Result { + fn half(ctx: &Context<'_>) -> Result<(c_double, SubType)> { assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); let value = ctx.get::(0)?; - Ok(value / 2f64) + Ok((value / 2f64, None)) } #[test] @@ -857,7 +871,7 @@ mod test { // This implementation of a regexp scalar function uses SQLite's auxiliary data // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular // expression multiple times within one query. - fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result { + fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<(bool, SubType)> { assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); type BoxError = Box; let regexp: std::sync::Arc = ctx @@ -874,7 +888,7 @@ mod test { regexp.is_match(text) }; - Ok(is_match) + Ok((is_match, None)) } #[test] @@ -925,7 +939,7 @@ mod test { ret.push_str(&s); } - Ok(ret) + Ok((ret, None)) }, )?; @@ -950,7 +964,7 @@ mod test { assert_eq!(ctx.get_aux::(0), Err(Error::GetAuxWrongType)); assert_eq!(*ctx.get_aux::(0)?.unwrap(), 100); } - Ok(true) + Ok((true, None)) })?; let res: bool = db.query_row( @@ -976,8 +990,12 @@ mod test { Ok(()) } - fn finalize(&self, _: &mut Context<'_>, sum: Option) -> Result> { - Ok(sum) + fn finalize( + &self, + _: &mut Context<'_>, + sum: Option, + ) -> Result<(Option, SubType)> { + Ok((sum, None)) } } @@ -991,8 +1009,8 @@ mod test { Ok(()) } - fn finalize(&self, _: &mut Context<'_>, sum: Option) -> Result { - Ok(sum.unwrap_or(0)) + fn finalize(&self, _: &mut Context<'_>, sum: Option) -> Result<(i64, SubType)> { + Ok((sum.unwrap_or(0), None)) } } @@ -1050,8 +1068,8 @@ mod test { Ok(()) } - fn value(&self, sum: Option<&i64>) -> Result> { - Ok(sum.copied()) + fn value(&self, sum: Option<&i64>) -> Result<(Option, SubType)> { + Ok((sum.copied(), None)) } } diff --git a/src/lib.rs b/src/lib.rs index 070b1a7..e1ea09f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1733,7 +1733,7 @@ mod test { crate::functions::FunctionFlags::default(), move |_| { interrupt_handle.interrupt(); - Ok(0) + Ok((0, None)) }, )?; From 109e6faa94e2f5b1556c054bd932dd461e32a56d Mon Sep 17 00:00:00 2001 From: gwenn Date: Sun, 30 Oct 2022 09:12:42 +0100 Subject: [PATCH 2/6] `modern_sqlite` not needed anymore --- src/functions.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index cc11a62..6c34445 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -506,7 +506,6 @@ impl InnerConnection { match t { Ok((Ok(ref value), sub_type)) => { set_result(ctx, value); - #[cfg(feature = "modern_sqlite")] // 3.9.0 if let Some(sub_type) = sub_type { ffi::sqlite3_result_subtype(ctx, *sub_type); } @@ -751,7 +750,6 @@ where match t { Ok((Ok(ref value), sub_type)) => { set_result(ctx, value); - #[cfg(feature = "modern_sqlite")] // 3.9.0 if let Some(sub_type) = sub_type { ffi::sqlite3_result_subtype(ctx, *sub_type); } @@ -801,7 +799,6 @@ where match t { Ok((Ok(ref value), sub_type)) => { set_result(ctx, value); - #[cfg(feature = "modern_sqlite")] // 3.9.0 if let Some(sub_type) = sub_type { ffi::sqlite3_result_subtype(ctx, *sub_type); } From 7ed8e0ef2f47cbd1aae5d1974a2f6ebca9ced556 Mon Sep 17 00:00:00 2001 From: gwenn Date: Sat, 20 Jan 2024 17:08:15 +0100 Subject: [PATCH 3/6] Introduce SqlFnOutput trait To keep compatibility with existing code --- src/functions.rs | 116 ++++++++++++++++++++++++----------------------- src/lib.rs | 2 +- 2 files changed, 60 insertions(+), 58 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index f599e63..ecc423c 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -34,7 +34,7 @@ //! regexp.is_match(text) //! }; //! -//! Ok((is_match, None)) +//! Ok(is_match) //! }, //! ) //! } @@ -66,7 +66,7 @@ use crate::ffi::sqlite3_context; use crate::ffi::sqlite3_value; use crate::context::set_result; -use crate::types::{FromSql, FromSqlError, ToSql, ValueRef}; +use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef}; use crate::{str_to_cstring, Connection, Error, InnerConnection, Result}; @@ -256,6 +256,33 @@ type AuxInner = Arc; /// Subtype of an SQL function pub type SubType = Option; +/// Result of an SQL function +pub trait SqlFnOutput { + /// Converts Rust value to SQLite value with an optional sub-type + fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>; +} + +impl SqlFnOutput for T { + fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> { + ToSql::to_sql(self).map(|o| (o, None)) + } +} + +unsafe fn sql_result(ctx: *mut sqlite3_context, r: Result) { + let t = r.as_ref().map(SqlFnOutput::to_sql); + + match t { + Ok(Ok((ref value, sub_type))) => { + set_result(ctx, value); + if let Some(sub_type) = sub_type { + ffi::sqlite3_result_subtype(ctx, sub_type); + } + } + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), + }; +} + /// Aggregate is the callback interface for user-defined /// aggregate function. /// @@ -264,7 +291,7 @@ pub type SubType = Option; pub trait Aggregate where A: RefUnwindSafe + UnwindSafe, - T: ToSql, + T: SqlFnOutput, { /// Initializes the aggregation context. Will be called prior to the first /// call to [`step()`](Aggregate::step) to set up the context for an @@ -285,7 +312,7 @@ where /// given `None`. /// /// The passed context will have no arguments. - fn finalize(&self, ctx: &mut Context<'_>, acc: Option) -> Result<(T, SubType)>; + fn finalize(&self, ctx: &mut Context<'_>, acc: Option) -> Result; } /// `WindowAggregate` is the callback interface for @@ -295,11 +322,11 @@ where pub trait WindowAggregate: Aggregate where A: RefUnwindSafe + UnwindSafe, - T: ToSql, + T: SqlFnOutput, { /// Returns the current value of the aggregate. Unlike xFinal, the /// implementation should not delete any context. - fn value(&self, acc: Option<&mut A>) -> Result<(T, SubType)>; + fn value(&self, acc: Option<&mut A>) -> Result; /// Removes a row from the current window. fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>; @@ -365,7 +392,7 @@ impl Connection { /// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, /// |ctx| { /// let value = ctx.get::(0)?; - /// Ok((value / 2f64, None)) + /// Ok(value / 2f64) /// }, /// )?; /// @@ -387,8 +414,8 @@ impl Connection { x_func: F, ) -> Result<()> where - F: FnMut(&Context<'_>) -> Result<(T, SubType)> + Send + UnwindSafe + 'static, - T: ToSql, + F: FnMut(&Context<'_>) -> Result + Send + UnwindSafe + 'static, + T: SqlFnOutput, { self.db .borrow_mut() @@ -412,7 +439,7 @@ impl Connection { where A: RefUnwindSafe + UnwindSafe, D: Aggregate + 'static, - T: ToSql, + T: SqlFnOutput, { self.db .borrow_mut() @@ -469,16 +496,16 @@ impl InnerConnection { x_func: F, ) -> Result<()> where - F: FnMut(&Context<'_>) -> Result<(T, SubType)> + Send + UnwindSafe + 'static, - T: ToSql, + F: FnMut(&Context<'_>) -> Result + Send + UnwindSafe + 'static, + T: SqlFnOutput, { unsafe extern "C" fn call_boxed_closure( ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value, ) where - F: FnMut(&Context<'_>) -> Result<(T, SubType)>, - T: ToSql, + F: FnMut(&Context<'_>) -> Result, + T: SqlFnOutput, { let r = catch_unwind(|| { let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::(); @@ -496,18 +523,7 @@ impl InnerConnection { } Ok(r) => r, }; - let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); - - match t { - Ok((Ok(ref value), sub_type)) => { - set_result(ctx, value); - if let Some(sub_type) = sub_type { - ffi::sqlite3_result_subtype(ctx, *sub_type); - } - } - Ok((Err(err), _)) => report_error(ctx, &err), - Err(err) => report_error(ctx, err), - } + sql_result(ctx, t); } let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); @@ -538,7 +554,7 @@ impl InnerConnection { where A: RefUnwindSafe + UnwindSafe, D: Aggregate + 'static, - T: ToSql, + T: SqlFnOutput, { let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); let c_name = str_to_cstring(fn_name)?; @@ -624,7 +640,7 @@ unsafe extern "C" fn call_boxed_step( ) where A: RefUnwindSafe + UnwindSafe, D: Aggregate, - T: ToSql, + T: SqlFnOutput, { let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) { pac @@ -710,7 +726,7 @@ unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) where A: RefUnwindSafe + UnwindSafe, D: Aggregate, - T: ToSql, + T: SqlFnOutput, { // Within the xFinal callback, it is customary to set N=0 in calls to // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. @@ -744,17 +760,7 @@ where } Ok(r) => r, }; - let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); - match t { - Ok((Ok(ref value), sub_type)) => { - set_result(ctx, value); - if let Some(sub_type) = sub_type { - ffi::sqlite3_result_subtype(ctx, *sub_type); - } - } - Ok((Err(err), _)) => report_error(ctx, &err), - Err(err) => report_error(ctx, err), - } + sql_result(ctx, t); } #[cfg(feature = "window")] @@ -806,13 +812,13 @@ mod test { #[cfg(feature = "window")] use crate::functions::WindowAggregate; - use crate::functions::{Aggregate, Context, FunctionFlags, SubType}; + use crate::functions::{Aggregate, Context, FunctionFlags}; use crate::{Connection, Error, Result}; - fn half(ctx: &Context<'_>) -> Result<(c_double, SubType)> { + fn half(ctx: &Context<'_>) -> Result { assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); let value = ctx.get::(0)?; - Ok((value / 2f64, None)) + Ok(value / 2f64) } #[test] @@ -851,7 +857,7 @@ mod test { // This implementation of a regexp scalar function uses SQLite's auxiliary data // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular // expression multiple times within one query. - fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<(bool, SubType)> { + fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result { assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); type BoxError = Box; let regexp: std::sync::Arc = ctx @@ -868,7 +874,7 @@ mod test { regexp.is_match(text) }; - Ok((is_match, None)) + Ok(is_match) } #[test] @@ -915,7 +921,7 @@ mod test { ret.push_str(&s); } - Ok((ret, None)) + Ok(ret) }, )?; @@ -940,7 +946,7 @@ mod test { assert_eq!(ctx.get_aux::(0), Err(Error::GetAuxWrongType)); assert_eq!(*ctx.get_aux::(0)?.unwrap(), 100); } - Ok((true, None)) + Ok(true) })?; let res: bool = @@ -963,12 +969,8 @@ mod test { Ok(()) } - fn finalize( - &self, - _: &mut Context<'_>, - sum: Option, - ) -> Result<(Option, SubType)> { - Ok((sum, None)) + fn finalize(&self, _: &mut Context<'_>, sum: Option) -> Result> { + Ok(sum) } } @@ -982,8 +984,8 @@ mod test { Ok(()) } - fn finalize(&self, _: &mut Context<'_>, sum: Option) -> Result<(i64, SubType)> { - Ok((sum.unwrap_or(0), None)) + fn finalize(&self, _: &mut Context<'_>, sum: Option) -> Result { + Ok(sum.unwrap_or(0)) } } @@ -1041,8 +1043,8 @@ mod test { Ok(()) } - fn value(&self, sum: Option<&mut i64>) -> Result<(Option, SubType)> { - Ok((sum.copied(), None)) + fn value(&self, sum: Option<&mut i64>) -> Result> { + Ok(sum.copied()) } } diff --git a/src/lib.rs b/src/lib.rs index f86ed67..1448707 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1801,7 +1801,7 @@ mod test { functions::FunctionFlags::default(), move |_| { interrupt_handle.interrupt(); - Ok((0, None)) + Ok(0) }, )?; From 13399c580868c9ba6a420e6b9ad8037dad3b1cce Mon Sep 17 00:00:00 2001 From: gwenn Date: Sat, 20 Jan 2024 17:22:58 +0100 Subject: [PATCH 4/6] Fix window impl --- src/functions.rs | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index ecc423c..00ac69f 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -464,7 +464,7 @@ impl Connection { where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate + 'static, - T: ToSql, + T: SqlFnOutput, { self.db .borrow_mut() @@ -585,7 +585,7 @@ impl InnerConnection { where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate + 'static, - T: ToSql, + T: SqlFnOutput, { let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr)); let c_name = str_to_cstring(fn_name)?; @@ -688,7 +688,7 @@ unsafe extern "C" fn call_boxed_inverse( ) where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate, - T: ToSql, + T: SqlFnOutput, { let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) { pac @@ -768,7 +768,7 @@ unsafe extern "C" fn call_boxed_value(ctx: *mut sqlite3_context) where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate, - T: ToSql, + T: SqlFnOutput, { // Within the xValue callback, it is customary to set N=0 in calls to // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. @@ -792,17 +792,7 @@ where } Ok(r) => r, }; - let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); - match t { - Ok((Ok(ref value), sub_type)) => { - set_result(ctx, value); - if let Some(sub_type) = sub_type { - ffi::sqlite3_result_subtype(ctx, *sub_type); - } - } - Ok((Err(err), _)) => report_error(ctx, &err), - Err(err) => report_error(ctx, err), - } + sql_result(ctx, t); } #[cfg(test)] From 83d67d5a29289b15bfda977b0fd532bb53829e94 Mon Sep 17 00:00:00 2001 From: gwenn Date: Sat, 20 Jan 2024 19:16:48 +0100 Subject: [PATCH 5/6] Test sub-type --- src/functions.rs | 44 ++++++++++++++++++++++++++++++++++++++++++-- src/types/to_sql.rs | 7 +++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 00ac69f..0b8c2b3 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -263,11 +263,18 @@ pub trait SqlFnOutput { } impl SqlFnOutput for T { + #[inline] fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> { ToSql::to_sql(self).map(|o| (o, None)) } } +impl SqlFnOutput for (T, SubType) { + fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> { + ToSql::to_sql(&self.0).map(|o| (o, self.1)) + } +} + unsafe fn sql_result(ctx: *mut sqlite3_context, r: Result) { let t = r.as_ref().map(SqlFnOutput::to_sql); @@ -802,8 +809,8 @@ mod test { #[cfg(feature = "window")] use crate::functions::WindowAggregate; - use crate::functions::{Aggregate, Context, FunctionFlags}; - use crate::{Connection, Error, Result}; + use crate::functions::{Aggregate, Context, FunctionFlags, SubType}; + use crate::{Connection, Error, Result, ValueRef}; fn half(ctx: &Context<'_>) -> Result { assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); @@ -1080,4 +1087,37 @@ mod test { assert_eq!(expected, results); Ok(()) } + + #[test] + fn test_sub_type() -> Result<()> { + fn test_getsubtype(ctx: &Context<'_>) -> Result { + Ok(ctx.get_subtype(0) as i32) + } + fn test_setsubtype<'a>(ctx: &'a Context<'_>) -> Result<(ValueRef<'a>, SubType)> { + use std::os::raw::c_uint; + let value = ctx.get_raw(0); + let sub_type = ctx.get::(1)?; + Ok((value, Some(sub_type))) + } + let db = Connection::open_in_memory()?; + db.create_scalar_function( + "test_getsubtype", + 1, + FunctionFlags::SQLITE_UTF8, + test_getsubtype, + )?; + db.create_scalar_function( + "test_setsubtype", + 2, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_RESULT_SUBTYPE, + test_setsubtype, + )?; + let result: i32 = db.one_column("SELECT test_getsubtype('hello');")?; + assert_eq!(0, result); + + let result: i32 = db.one_column("SELECT test_getsubtype(test_setsubtype('hello',123));")?; + assert_eq!(123, result); + + Ok(()) + } } diff --git a/src/types/to_sql.rs b/src/types/to_sql.rs index 29e63ab..855d339 100644 --- a/src/types/to_sql.rs +++ b/src/types/to_sql.rs @@ -292,6 +292,13 @@ impl ToSql for Value { } } +impl<'a> ToSql for ValueRef<'a> { + #[inline] + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::Borrowed(*self)) + } +} + impl ToSql for Option { #[inline] fn to_sql(&self) -> Result> { From f48c5781a19e7c9024a8d64fe1fe73ea20f2cc44 Mon Sep 17 00:00:00 2001 From: gwenn Date: Sun, 21 Jan 2024 10:13:07 +0100 Subject: [PATCH 6/6] Introduce SqlFnArg --- src/context.rs | 11 ++++++++++- src/functions.rs | 47 ++++++++++++++++++++++++++++++++------------- src/pragma.rs | 7 +++++++ src/statement.rs | 7 +++++++ src/types/to_sql.rs | 14 +++++++------- src/vtab/mod.rs | 2 +- 6 files changed, 66 insertions(+), 22 deletions(-) diff --git a/src/context.rs b/src/context.rs index bcaefc9..6a8bb79 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,5 +1,6 @@ //! Code related to `sqlite3_context` common to `functions` and `vtab` modules. +use libsqlite3_sys::sqlite3_value; use std::os::raw::{c_int, c_void}; #[cfg(feature = "array")] use std::rc::Rc; @@ -16,7 +17,11 @@ use crate::vtab::array::{free_array, ARRAY_TYPE}; // is often known to the compiler, and thus const prop/DCE can substantially // simplify the function. #[inline] -pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput<'_>) { +pub(super) unsafe fn set_result( + ctx: *mut sqlite3_context, + args: &[*mut sqlite3_value], + result: &ToSqlOutput<'_>, +) { let value = match *result { ToSqlOutput::Borrowed(v) => v, ToSqlOutput::Owned(ref v) => ValueRef::from(v), @@ -26,6 +31,10 @@ pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput< // TODO sqlite3_result_zeroblob64 // 3.8.11 return ffi::sqlite3_result_zeroblob(ctx, len); } + #[cfg(feature = "functions")] + ToSqlOutput::Arg(i) => { + return ffi::sqlite3_result_value(ctx, args[i]); + } #[cfg(feature = "array")] ToSqlOutput::Array(ref a) => { return ffi::sqlite3_result_pointer( diff --git a/src/functions.rs b/src/functions.rs index 0b8c2b3..f4a508c 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -149,6 +149,15 @@ impl Context<'_> { unsafe { ValueRef::from_value(arg) } } + /// Returns the `idx`th argument as a `SqlFnArg`. + /// To be used when the SQL function result is one of its arguments. + #[inline] + #[must_use] + pub fn get_arg(&self, idx: usize) -> SqlFnArg { + assert!(idx < self.len()); + SqlFnArg { idx } + } + /// Returns the subtype of `idx`th argument. /// /// # Failure @@ -275,12 +284,26 @@ impl SqlFnOutput for (T, SubType) { } } -unsafe fn sql_result(ctx: *mut sqlite3_context, r: Result) { +/// n-th arg of an SQL scalar function +pub struct SqlFnArg { + idx: usize, +} +impl ToSql for SqlFnArg { + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::Arg(self.idx)) + } +} + +unsafe fn sql_result( + ctx: *mut sqlite3_context, + args: &[*mut sqlite3_value], + r: Result, +) { let t = r.as_ref().map(SqlFnOutput::to_sql); match t { Ok(Ok((ref value, sub_type))) => { - set_result(ctx, value); + set_result(ctx, args, value); if let Some(sub_type) = sub_type { ffi::sqlite3_result_subtype(ctx, sub_type); } @@ -514,13 +537,11 @@ impl InnerConnection { F: FnMut(&Context<'_>) -> Result, T: SqlFnOutput, { + let args = slice::from_raw_parts(argv, argc as usize); let r = catch_unwind(|| { let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::(); assert!(!boxed_f.is_null(), "Internal error - null function pointer"); - let ctx = Context { - ctx, - args: slice::from_raw_parts(argv, argc as usize), - }; + let ctx = Context { ctx, args }; (*boxed_f)(&ctx) }); let t = match r { @@ -530,7 +551,7 @@ impl InnerConnection { } Ok(r) => r, }; - sql_result(ctx, t); + sql_result(ctx, args, t); } let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); @@ -767,7 +788,7 @@ where } Ok(r) => r, }; - sql_result(ctx, t); + sql_result(ctx, &[], t); } #[cfg(feature = "window")] @@ -799,7 +820,7 @@ where } Ok(r) => r, }; - sql_result(ctx, t); + sql_result(ctx, &[], t); } #[cfg(test)] @@ -809,8 +830,8 @@ mod test { #[cfg(feature = "window")] use crate::functions::WindowAggregate; - use crate::functions::{Aggregate, Context, FunctionFlags, SubType}; - use crate::{Connection, Error, Result, ValueRef}; + use crate::functions::{Aggregate, Context, FunctionFlags, SqlFnArg, SubType}; + use crate::{Connection, Error, Result}; fn half(ctx: &Context<'_>) -> Result { assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); @@ -1093,9 +1114,9 @@ mod test { fn test_getsubtype(ctx: &Context<'_>) -> Result { Ok(ctx.get_subtype(0) as i32) } - fn test_setsubtype<'a>(ctx: &'a Context<'_>) -> Result<(ValueRef<'a>, SubType)> { + fn test_setsubtype(ctx: &Context<'_>) -> Result<(SqlFnArg, SubType)> { use std::os::raw::c_uint; - let value = ctx.get_raw(0); + let value = ctx.get_arg(0); let sub_type = ctx.get::(1)?; Ok((value, Some(sub_type))) } diff --git a/src/pragma.rs b/src/pragma.rs index 46bbde1..f1c5049 100644 --- a/src/pragma.rs +++ b/src/pragma.rs @@ -70,6 +70,13 @@ impl Sql { Some(format!("Unsupported value \"{value:?}\"")), )); } + #[cfg(feature = "functions")] + ToSqlOutput::Arg(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{value:?}\"")), + )); + } #[cfg(feature = "array")] ToSqlOutput::Array(_) => { return Err(Error::SqliteFailure( diff --git a/src/statement.rs b/src/statement.rs index edc1870..ac8f1db 100644 --- a/src/statement.rs +++ b/src/statement.rs @@ -606,6 +606,13 @@ impl Statement<'_> { .conn .decode_result(unsafe { ffi::sqlite3_bind_zeroblob(ptr, col as c_int, len) }); } + #[cfg(feature = "functions")] + ToSqlOutput::Arg(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{value:?}\"")), + )); + } #[cfg(feature = "array")] ToSqlOutput::Array(a) => { return self.conn.decode_result(unsafe { diff --git a/src/types/to_sql.rs b/src/types/to_sql.rs index 855d339..a5cb1ff 100644 --- a/src/types/to_sql.rs +++ b/src/types/to_sql.rs @@ -22,6 +22,11 @@ pub enum ToSqlOutput<'a> { #[cfg_attr(docsrs, doc(cfg(feature = "blob")))] ZeroBlob(i32), + /// n-th arg of an SQL scalar function + #[cfg(feature = "functions")] + #[cfg_attr(docsrs, doc(cfg(feature = "functions")))] + Arg(usize), + /// `feature = "array"` #[cfg(feature = "array")] #[cfg_attr(docsrs, doc(cfg(feature = "array")))] @@ -107,6 +112,8 @@ impl ToSql for ToSqlOutput<'_> { #[cfg(feature = "blob")] ToSqlOutput::ZeroBlob(i) => ToSqlOutput::ZeroBlob(i), + #[cfg(feature = "functions")] + ToSqlOutput::Arg(i) => ToSqlOutput::Arg(i), #[cfg(feature = "array")] ToSqlOutput::Array(ref a) => ToSqlOutput::Array(a.clone()), }) @@ -292,13 +299,6 @@ impl ToSql for Value { } } -impl<'a> ToSql for ValueRef<'a> { - #[inline] - fn to_sql(&self) -> Result> { - Ok(ToSqlOutput::Borrowed(*self)) - } -} - impl ToSql for Option { #[inline] fn to_sql(&self) -> Result> { diff --git a/src/vtab/mod.rs b/src/vtab/mod.rs index 4ee1693..333a3bf 100644 --- a/src/vtab/mod.rs +++ b/src/vtab/mod.rs @@ -704,7 +704,7 @@ impl Context { #[inline] pub fn set_result(&mut self, value: &T) -> Result<()> { let t = value.to_sql()?; - unsafe { set_result(self.0, &t) }; + unsafe { set_result(self.0, &[], &t) }; Ok(()) }