diff --git a/src/context.rs b/src/context.rs index bcaefc9..6a8bb79 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,5 +1,6 @@ //! Code related to `sqlite3_context` common to `functions` and `vtab` modules. +use libsqlite3_sys::sqlite3_value; use std::os::raw::{c_int, c_void}; #[cfg(feature = "array")] use std::rc::Rc; @@ -16,7 +17,11 @@ use crate::vtab::array::{free_array, ARRAY_TYPE}; // is often known to the compiler, and thus const prop/DCE can substantially // simplify the function. #[inline] -pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput<'_>) { +pub(super) unsafe fn set_result( + ctx: *mut sqlite3_context, + args: &[*mut sqlite3_value], + result: &ToSqlOutput<'_>, +) { let value = match *result { ToSqlOutput::Borrowed(v) => v, ToSqlOutput::Owned(ref v) => ValueRef::from(v), @@ -26,6 +31,10 @@ pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput< // TODO sqlite3_result_zeroblob64 // 3.8.11 return ffi::sqlite3_result_zeroblob(ctx, len); } + #[cfg(feature = "functions")] + ToSqlOutput::Arg(i) => { + return ffi::sqlite3_result_value(ctx, args[i]); + } #[cfg(feature = "array")] ToSqlOutput::Array(ref a) => { return ffi::sqlite3_result_pointer( diff --git a/src/functions.rs b/src/functions.rs index 576d1b0..f4a508c 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -66,7 +66,7 @@ use crate::ffi::sqlite3_context; use crate::ffi::sqlite3_value; use crate::context::set_result; -use crate::types::{FromSql, FromSqlError, ToSql, ValueRef}; +use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef}; use crate::{str_to_cstring, Connection, Error, InnerConnection, Result}; @@ -149,6 +149,15 @@ impl Context<'_> { unsafe { ValueRef::from_value(arg) } } + /// Returns the `idx`th argument as a `SqlFnArg`. + /// To be used when the SQL function result is one of its arguments. + #[inline] + #[must_use] + pub fn get_arg(&self, idx: usize) -> SqlFnArg { + assert!(idx < self.len()); + SqlFnArg { idx } + } + /// Returns the subtype of `idx`th argument. /// /// # Failure @@ -232,11 +241,6 @@ impl Context<'_> { phantom: PhantomData, }) } - - /// Set the Subtype of an SQL function - pub fn set_result_subtype(&self, sub_type: std::os::raw::c_uint) { - unsafe { ffi::sqlite3_result_subtype(self.ctx, sub_type) }; - } } /// A reference to a connection handle with a lifetime bound to something. @@ -258,6 +262,57 @@ impl Deref for ConnectionRef<'_> { type AuxInner = Arc; +/// Subtype of an SQL function +pub type SubType = Option; + +/// Result of an SQL function +pub trait SqlFnOutput { + /// Converts Rust value to SQLite value with an optional sub-type + fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>; +} + +impl SqlFnOutput for T { + #[inline] + fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> { + ToSql::to_sql(self).map(|o| (o, None)) + } +} + +impl SqlFnOutput for (T, SubType) { + fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> { + ToSql::to_sql(&self.0).map(|o| (o, self.1)) + } +} + +/// n-th arg of an SQL scalar function +pub struct SqlFnArg { + idx: usize, +} +impl ToSql for SqlFnArg { + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::Arg(self.idx)) + } +} + +unsafe fn sql_result( + ctx: *mut sqlite3_context, + args: &[*mut sqlite3_value], + r: Result, +) { + let t = r.as_ref().map(SqlFnOutput::to_sql); + + match t { + Ok(Ok((ref value, sub_type))) => { + set_result(ctx, args, value); + if let Some(sub_type) = sub_type { + ffi::sqlite3_result_subtype(ctx, sub_type); + } + } + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), + }; +} + /// Aggregate is the callback interface for user-defined /// aggregate function. /// @@ -266,7 +321,7 @@ type AuxInner = Arc; pub trait Aggregate where A: RefUnwindSafe + UnwindSafe, - T: ToSql, + T: SqlFnOutput, { /// Initializes the aggregation context. Will be called prior to the first /// call to [`step()`](Aggregate::step) to set up the context for an @@ -297,7 +352,7 @@ where pub trait WindowAggregate: Aggregate where A: RefUnwindSafe + UnwindSafe, - T: ToSql, + T: SqlFnOutput, { /// Returns the current value of the aggregate. Unlike xFinal, the /// implementation should not delete any context. @@ -330,6 +385,8 @@ bitflags::bitflags! { const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0 /// Means that the function is unlikely to cause problems even if misused. const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0 + /// Indicates to SQLite that a function might call `sqlite3_result_subtype()` to cause a sub-type to be associated with its result. + const SQLITE_RESULT_SUBTYPE = 0x0000_0100_0000; // 3.45.0 } } @@ -388,7 +445,7 @@ impl Connection { ) -> Result<()> where F: FnMut(&Context<'_>) -> Result + Send + UnwindSafe + 'static, - T: ToSql, + T: SqlFnOutput, { self.db .borrow_mut() @@ -412,7 +469,7 @@ impl Connection { where A: RefUnwindSafe + UnwindSafe, D: Aggregate + 'static, - T: ToSql, + T: SqlFnOutput, { self.db .borrow_mut() @@ -437,7 +494,7 @@ impl Connection { where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate + 'static, - T: ToSql, + T: SqlFnOutput, { self.db .borrow_mut() @@ -470,7 +527,7 @@ impl InnerConnection { ) -> Result<()> where F: FnMut(&Context<'_>) -> Result + Send + UnwindSafe + 'static, - T: ToSql, + T: SqlFnOutput, { unsafe extern "C" fn call_boxed_closure( ctx: *mut sqlite3_context, @@ -478,15 +535,13 @@ impl InnerConnection { argv: *mut *mut sqlite3_value, ) where F: FnMut(&Context<'_>) -> Result, - T: ToSql, + T: SqlFnOutput, { + let args = slice::from_raw_parts(argv, argc as usize); let r = catch_unwind(|| { let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::(); assert!(!boxed_f.is_null(), "Internal error - null function pointer"); - let ctx = Context { - ctx, - args: slice::from_raw_parts(argv, argc as usize), - }; + let ctx = Context { ctx, args }; (*boxed_f)(&ctx) }); let t = match r { @@ -496,13 +551,7 @@ impl InnerConnection { } Ok(r) => r, }; - let t = t.as_ref().map(|t| ToSql::to_sql(t)); - - match t { - Ok(Ok(ref value)) => set_result(ctx, value), - Ok(Err(err)) => report_error(ctx, &err), - Err(err) => report_error(ctx, err), - } + sql_result(ctx, args, t); } let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); @@ -533,7 +582,7 @@ impl InnerConnection { where A: RefUnwindSafe + UnwindSafe, D: Aggregate + 'static, - T: ToSql, + T: SqlFnOutput, { let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); let c_name = str_to_cstring(fn_name)?; @@ -564,7 +613,7 @@ impl InnerConnection { where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate + 'static, - T: ToSql, + T: SqlFnOutput, { let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr)); let c_name = str_to_cstring(fn_name)?; @@ -619,7 +668,7 @@ unsafe extern "C" fn call_boxed_step( ) where A: RefUnwindSafe + UnwindSafe, D: Aggregate, - T: ToSql, + T: SqlFnOutput, { let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) { pac @@ -667,7 +716,7 @@ unsafe extern "C" fn call_boxed_inverse( ) where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate, - T: ToSql, + T: SqlFnOutput, { let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) { pac @@ -705,7 +754,7 @@ unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) where A: RefUnwindSafe + UnwindSafe, D: Aggregate, - T: ToSql, + T: SqlFnOutput, { // Within the xFinal callback, it is customary to set N=0 in calls to // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. @@ -739,12 +788,7 @@ where } Ok(r) => r, }; - let t = t.as_ref().map(|t| ToSql::to_sql(t)); - match t { - Ok(Ok(ref value)) => set_result(ctx, value), - Ok(Err(err)) => report_error(ctx, &err), - Err(err) => report_error(ctx, err), - } + sql_result(ctx, &[], t); } #[cfg(feature = "window")] @@ -752,7 +796,7 @@ unsafe extern "C" fn call_boxed_value(ctx: *mut sqlite3_context) where A: RefUnwindSafe + UnwindSafe, W: WindowAggregate, - T: ToSql, + T: SqlFnOutput, { // Within the xValue callback, it is customary to set N=0 in calls to // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. @@ -776,12 +820,7 @@ where } Ok(r) => r, }; - let t = t.as_ref().map(|t| ToSql::to_sql(t)); - match t { - Ok(Ok(ref value)) => set_result(ctx, value), - Ok(Err(err)) => report_error(ctx, &err), - Err(err) => report_error(ctx, err), - } + sql_result(ctx, &[], t); } #[cfg(test)] @@ -791,7 +830,7 @@ mod test { #[cfg(feature = "window")] use crate::functions::WindowAggregate; - use crate::functions::{Aggregate, Context, FunctionFlags}; + use crate::functions::{Aggregate, Context, FunctionFlags, SqlFnArg, SubType}; use crate::{Connection, Error, Result}; fn half(ctx: &Context<'_>) -> Result { @@ -1069,4 +1108,37 @@ mod test { assert_eq!(expected, results); Ok(()) } + + #[test] + fn test_sub_type() -> Result<()> { + fn test_getsubtype(ctx: &Context<'_>) -> Result { + Ok(ctx.get_subtype(0) as i32) + } + fn test_setsubtype(ctx: &Context<'_>) -> Result<(SqlFnArg, SubType)> { + use std::os::raw::c_uint; + let value = ctx.get_arg(0); + let sub_type = ctx.get::(1)?; + Ok((value, Some(sub_type))) + } + let db = Connection::open_in_memory()?; + db.create_scalar_function( + "test_getsubtype", + 1, + FunctionFlags::SQLITE_UTF8, + test_getsubtype, + )?; + db.create_scalar_function( + "test_setsubtype", + 2, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_RESULT_SUBTYPE, + test_setsubtype, + )?; + let result: i32 = db.one_column("SELECT test_getsubtype('hello');")?; + assert_eq!(0, result); + + let result: i32 = db.one_column("SELECT test_getsubtype(test_setsubtype('hello',123));")?; + assert_eq!(123, result); + + Ok(()) + } } diff --git a/src/pragma.rs b/src/pragma.rs index 46bbde1..f1c5049 100644 --- a/src/pragma.rs +++ b/src/pragma.rs @@ -70,6 +70,13 @@ impl Sql { Some(format!("Unsupported value \"{value:?}\"")), )); } + #[cfg(feature = "functions")] + ToSqlOutput::Arg(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{value:?}\"")), + )); + } #[cfg(feature = "array")] ToSqlOutput::Array(_) => { return Err(Error::SqliteFailure( diff --git a/src/statement.rs b/src/statement.rs index 3ee9f1c..0a0c116 100644 --- a/src/statement.rs +++ b/src/statement.rs @@ -606,6 +606,13 @@ impl Statement<'_> { .conn .decode_result(unsafe { ffi::sqlite3_bind_zeroblob(ptr, col as c_int, len) }); } + #[cfg(feature = "functions")] + ToSqlOutput::Arg(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{value:?}\"")), + )); + } #[cfg(feature = "array")] ToSqlOutput::Array(a) => { return self.conn.decode_result(unsafe { diff --git a/src/types/to_sql.rs b/src/types/to_sql.rs index 29e63ab..a5cb1ff 100644 --- a/src/types/to_sql.rs +++ b/src/types/to_sql.rs @@ -22,6 +22,11 @@ pub enum ToSqlOutput<'a> { #[cfg_attr(docsrs, doc(cfg(feature = "blob")))] ZeroBlob(i32), + /// n-th arg of an SQL scalar function + #[cfg(feature = "functions")] + #[cfg_attr(docsrs, doc(cfg(feature = "functions")))] + Arg(usize), + /// `feature = "array"` #[cfg(feature = "array")] #[cfg_attr(docsrs, doc(cfg(feature = "array")))] @@ -107,6 +112,8 @@ impl ToSql for ToSqlOutput<'_> { #[cfg(feature = "blob")] ToSqlOutput::ZeroBlob(i) => ToSqlOutput::ZeroBlob(i), + #[cfg(feature = "functions")] + ToSqlOutput::Arg(i) => ToSqlOutput::Arg(i), #[cfg(feature = "array")] ToSqlOutput::Array(ref a) => ToSqlOutput::Array(a.clone()), }) diff --git a/src/vtab/mod.rs b/src/vtab/mod.rs index 4ee1693..333a3bf 100644 --- a/src/vtab/mod.rs +++ b/src/vtab/mod.rs @@ -704,7 +704,7 @@ impl Context { #[inline] pub fn set_result(&mut self, value: &T) -> Result<()> { let t = value.to_sql()?; - unsafe { set_result(self.0, &t) }; + unsafe { set_result(self.0, &[], &t) }; Ok(()) }