diff --git a/src/context.rs b/src/context.rs index bcaefc9..6a8bb79 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,5 +1,6 @@ //! Code related to `sqlite3_context` common to `functions` and `vtab` modules. +use libsqlite3_sys::sqlite3_value; use std::os::raw::{c_int, c_void}; #[cfg(feature = "array")] use std::rc::Rc; @@ -16,7 +17,11 @@ use crate::vtab::array::{free_array, ARRAY_TYPE}; // is often known to the compiler, and thus const prop/DCE can substantially // simplify the function. #[inline] -pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput<'_>) { +pub(super) unsafe fn set_result( + ctx: *mut sqlite3_context, + args: &[*mut sqlite3_value], + result: &ToSqlOutput<'_>, +) { let value = match *result { ToSqlOutput::Borrowed(v) => v, ToSqlOutput::Owned(ref v) => ValueRef::from(v), @@ -26,6 +31,10 @@ pub(super) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput< // TODO sqlite3_result_zeroblob64 // 3.8.11 return ffi::sqlite3_result_zeroblob(ctx, len); } + #[cfg(feature = "functions")] + ToSqlOutput::Arg(i) => { + return ffi::sqlite3_result_value(ctx, args[i]); + } #[cfg(feature = "array")] ToSqlOutput::Array(ref a) => { return ffi::sqlite3_result_pointer( diff --git a/src/functions.rs b/src/functions.rs index 0b8c2b3..f4a508c 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -149,6 +149,15 @@ impl Context<'_> { unsafe { ValueRef::from_value(arg) } } + /// Returns the `idx`th argument as a `SqlFnArg`. + /// To be used when the SQL function result is one of its arguments. + #[inline] + #[must_use] + pub fn get_arg(&self, idx: usize) -> SqlFnArg { + assert!(idx < self.len()); + SqlFnArg { idx } + } + /// Returns the subtype of `idx`th argument. /// /// # Failure @@ -275,12 +284,26 @@ impl SqlFnOutput for (T, SubType) { } } -unsafe fn sql_result(ctx: *mut sqlite3_context, r: Result) { +/// n-th arg of an SQL scalar function +pub struct SqlFnArg { + idx: usize, +} +impl ToSql for SqlFnArg { + fn to_sql(&self) -> Result> { + Ok(ToSqlOutput::Arg(self.idx)) + } +} + +unsafe fn sql_result( + ctx: *mut sqlite3_context, + args: &[*mut sqlite3_value], + r: Result, +) { let t = r.as_ref().map(SqlFnOutput::to_sql); match t { Ok(Ok((ref value, sub_type))) => { - set_result(ctx, value); + set_result(ctx, args, value); if let Some(sub_type) = sub_type { ffi::sqlite3_result_subtype(ctx, sub_type); } @@ -514,13 +537,11 @@ impl InnerConnection { F: FnMut(&Context<'_>) -> Result, T: SqlFnOutput, { + let args = slice::from_raw_parts(argv, argc as usize); let r = catch_unwind(|| { let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::(); assert!(!boxed_f.is_null(), "Internal error - null function pointer"); - let ctx = Context { - ctx, - args: slice::from_raw_parts(argv, argc as usize), - }; + let ctx = Context { ctx, args }; (*boxed_f)(&ctx) }); let t = match r { @@ -530,7 +551,7 @@ impl InnerConnection { } Ok(r) => r, }; - sql_result(ctx, t); + sql_result(ctx, args, t); } let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); @@ -767,7 +788,7 @@ where } Ok(r) => r, }; - sql_result(ctx, t); + sql_result(ctx, &[], t); } #[cfg(feature = "window")] @@ -799,7 +820,7 @@ where } Ok(r) => r, }; - sql_result(ctx, t); + sql_result(ctx, &[], t); } #[cfg(test)] @@ -809,8 +830,8 @@ mod test { #[cfg(feature = "window")] use crate::functions::WindowAggregate; - use crate::functions::{Aggregate, Context, FunctionFlags, SubType}; - use crate::{Connection, Error, Result, ValueRef}; + use crate::functions::{Aggregate, Context, FunctionFlags, SqlFnArg, SubType}; + use crate::{Connection, Error, Result}; fn half(ctx: &Context<'_>) -> Result { assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); @@ -1093,9 +1114,9 @@ mod test { fn test_getsubtype(ctx: &Context<'_>) -> Result { Ok(ctx.get_subtype(0) as i32) } - fn test_setsubtype<'a>(ctx: &'a Context<'_>) -> Result<(ValueRef<'a>, SubType)> { + fn test_setsubtype(ctx: &Context<'_>) -> Result<(SqlFnArg, SubType)> { use std::os::raw::c_uint; - let value = ctx.get_raw(0); + let value = ctx.get_arg(0); let sub_type = ctx.get::(1)?; Ok((value, Some(sub_type))) } diff --git a/src/pragma.rs b/src/pragma.rs index 46bbde1..f1c5049 100644 --- a/src/pragma.rs +++ b/src/pragma.rs @@ -70,6 +70,13 @@ impl Sql { Some(format!("Unsupported value \"{value:?}\"")), )); } + #[cfg(feature = "functions")] + ToSqlOutput::Arg(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{value:?}\"")), + )); + } #[cfg(feature = "array")] ToSqlOutput::Array(_) => { return Err(Error::SqliteFailure( diff --git a/src/statement.rs b/src/statement.rs index edc1870..ac8f1db 100644 --- a/src/statement.rs +++ b/src/statement.rs @@ -606,6 +606,13 @@ impl Statement<'_> { .conn .decode_result(unsafe { ffi::sqlite3_bind_zeroblob(ptr, col as c_int, len) }); } + #[cfg(feature = "functions")] + ToSqlOutput::Arg(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{value:?}\"")), + )); + } #[cfg(feature = "array")] ToSqlOutput::Array(a) => { return self.conn.decode_result(unsafe { diff --git a/src/types/to_sql.rs b/src/types/to_sql.rs index 855d339..a5cb1ff 100644 --- a/src/types/to_sql.rs +++ b/src/types/to_sql.rs @@ -22,6 +22,11 @@ pub enum ToSqlOutput<'a> { #[cfg_attr(docsrs, doc(cfg(feature = "blob")))] ZeroBlob(i32), + /// n-th arg of an SQL scalar function + #[cfg(feature = "functions")] + #[cfg_attr(docsrs, doc(cfg(feature = "functions")))] + Arg(usize), + /// `feature = "array"` #[cfg(feature = "array")] #[cfg_attr(docsrs, doc(cfg(feature = "array")))] @@ -107,6 +112,8 @@ impl ToSql for ToSqlOutput<'_> { #[cfg(feature = "blob")] ToSqlOutput::ZeroBlob(i) => ToSqlOutput::ZeroBlob(i), + #[cfg(feature = "functions")] + ToSqlOutput::Arg(i) => ToSqlOutput::Arg(i), #[cfg(feature = "array")] ToSqlOutput::Array(ref a) => ToSqlOutput::Array(a.clone()), }) @@ -292,13 +299,6 @@ impl ToSql for Value { } } -impl<'a> ToSql for ValueRef<'a> { - #[inline] - fn to_sql(&self) -> Result> { - Ok(ToSqlOutput::Borrowed(*self)) - } -} - impl ToSql for Option { #[inline] fn to_sql(&self) -> Result> { diff --git a/src/vtab/mod.rs b/src/vtab/mod.rs index 4ee1693..333a3bf 100644 --- a/src/vtab/mod.rs +++ b/src/vtab/mod.rs @@ -704,7 +704,7 @@ impl Context { #[inline] pub fn set_result(&mut self, value: &T) -> Result<()> { let t = value.to_sql()?; - unsafe { set_result(self.0, &t) }; + unsafe { set_result(self.0, &[], &t) }; Ok(()) }