Introduce SqlFnOutput trait

To keep compatibility with existing code
This commit is contained in:
gwenn 2024-01-20 17:08:15 +01:00
parent 796358a312
commit 7ed8e0ef2f
2 changed files with 60 additions and 58 deletions

View File

@ -34,7 +34,7 @@
//! regexp.is_match(text) //! regexp.is_match(text)
//! }; //! };
//! //!
//! Ok((is_match, None)) //! Ok(is_match)
//! }, //! },
//! ) //! )
//! } //! }
@ -66,7 +66,7 @@ use crate::ffi::sqlite3_context;
use crate::ffi::sqlite3_value; use crate::ffi::sqlite3_value;
use crate::context::set_result; use crate::context::set_result;
use crate::types::{FromSql, FromSqlError, ToSql, ValueRef}; use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};
use crate::{str_to_cstring, Connection, Error, InnerConnection, Result}; use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
@ -256,6 +256,33 @@ type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
/// Subtype of an SQL function /// Subtype of an SQL function
pub type SubType = Option<std::os::raw::c_uint>; pub type SubType = Option<std::os::raw::c_uint>;
/// Result of an SQL function
pub trait SqlFnOutput {
/// Converts Rust value to SQLite value with an optional sub-type
fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>;
}
impl<T: ToSql> SqlFnOutput for T {
fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
ToSql::to_sql(self).map(|o| (o, None))
}
}
unsafe fn sql_result<T: SqlFnOutput>(ctx: *mut sqlite3_context, r: Result<T>) {
let t = r.as_ref().map(SqlFnOutput::to_sql);
match t {
Ok(Ok((ref value, sub_type))) => {
set_result(ctx, value);
if let Some(sub_type) = sub_type {
ffi::sqlite3_result_subtype(ctx, sub_type);
}
}
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
};
}
/// Aggregate is the callback interface for user-defined /// Aggregate is the callback interface for user-defined
/// aggregate function. /// aggregate function.
/// ///
@ -264,7 +291,7 @@ pub type SubType = Option<std::os::raw::c_uint>;
pub trait Aggregate<A, T> pub trait Aggregate<A, T>
where where
A: RefUnwindSafe + UnwindSafe, A: RefUnwindSafe + UnwindSafe,
T: ToSql, T: SqlFnOutput,
{ {
/// Initializes the aggregation context. Will be called prior to the first /// Initializes the aggregation context. Will be called prior to the first
/// call to [`step()`](Aggregate::step) to set up the context for an /// call to [`step()`](Aggregate::step) to set up the context for an
@ -285,7 +312,7 @@ where
/// given `None`. /// given `None`.
/// ///
/// The passed context will have no arguments. /// The passed context will have no arguments.
fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<(T, SubType)>; fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<T>;
} }
/// `WindowAggregate` is the callback interface for /// `WindowAggregate` is the callback interface for
@ -295,11 +322,11 @@ where
pub trait WindowAggregate<A, T>: Aggregate<A, T> pub trait WindowAggregate<A, T>: Aggregate<A, T>
where where
A: RefUnwindSafe + UnwindSafe, A: RefUnwindSafe + UnwindSafe,
T: ToSql, T: SqlFnOutput,
{ {
/// Returns the current value of the aggregate. Unlike xFinal, the /// Returns the current value of the aggregate. Unlike xFinal, the
/// implementation should not delete any context. /// implementation should not delete any context.
fn value(&self, acc: Option<&mut A>) -> Result<(T, SubType)>; fn value(&self, acc: Option<&mut A>) -> Result<T>;
/// Removes a row from the current window. /// Removes a row from the current window.
fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>; fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
@ -365,7 +392,7 @@ impl Connection {
/// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, /// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
/// |ctx| { /// |ctx| {
/// let value = ctx.get::<f64>(0)?; /// let value = ctx.get::<f64>(0)?;
/// Ok((value / 2f64, None)) /// Ok(value / 2f64)
/// }, /// },
/// )?; /// )?;
/// ///
@ -387,8 +414,8 @@ impl Connection {
x_func: F, x_func: F,
) -> Result<()> ) -> Result<()>
where where
F: FnMut(&Context<'_>) -> Result<(T, SubType)> + Send + UnwindSafe + 'static, F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
T: ToSql, T: SqlFnOutput,
{ {
self.db self.db
.borrow_mut() .borrow_mut()
@ -412,7 +439,7 @@ impl Connection {
where where
A: RefUnwindSafe + UnwindSafe, A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T> + 'static, D: Aggregate<A, T> + 'static,
T: ToSql, T: SqlFnOutput,
{ {
self.db self.db
.borrow_mut() .borrow_mut()
@ -469,16 +496,16 @@ impl InnerConnection {
x_func: F, x_func: F,
) -> Result<()> ) -> Result<()>
where where
F: FnMut(&Context<'_>) -> Result<(T, SubType)> + Send + UnwindSafe + 'static, F: FnMut(&Context<'_>) -> Result<T> + Send + UnwindSafe + 'static,
T: ToSql, T: SqlFnOutput,
{ {
unsafe extern "C" fn call_boxed_closure<F, T>( unsafe extern "C" fn call_boxed_closure<F, T>(
ctx: *mut sqlite3_context, ctx: *mut sqlite3_context,
argc: c_int, argc: c_int,
argv: *mut *mut sqlite3_value, argv: *mut *mut sqlite3_value,
) where ) where
F: FnMut(&Context<'_>) -> Result<(T, SubType)>, F: FnMut(&Context<'_>) -> Result<T>,
T: ToSql, T: SqlFnOutput,
{ {
let r = catch_unwind(|| { let r = catch_unwind(|| {
let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>(); let boxed_f: *mut F = ffi::sqlite3_user_data(ctx).cast::<F>();
@ -496,18 +523,7 @@ impl InnerConnection {
} }
Ok(r) => r, Ok(r) => r,
}; };
let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); sql_result(ctx, t);
match t {
Ok((Ok(ref value), sub_type)) => {
set_result(ctx, value);
if let Some(sub_type) = sub_type {
ffi::sqlite3_result_subtype(ctx, *sub_type);
}
}
Ok((Err(err), _)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
}
} }
let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
@ -538,7 +554,7 @@ impl InnerConnection {
where where
A: RefUnwindSafe + UnwindSafe, A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T> + 'static, D: Aggregate<A, T> + 'static,
T: ToSql, T: SqlFnOutput,
{ {
let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
let c_name = str_to_cstring(fn_name)?; let c_name = str_to_cstring(fn_name)?;
@ -624,7 +640,7 @@ unsafe extern "C" fn call_boxed_step<A, D, T>(
) where ) where
A: RefUnwindSafe + UnwindSafe, A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>, D: Aggregate<A, T>,
T: ToSql, T: SqlFnOutput,
{ {
let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) { let pac = if let Some(pac) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
pac pac
@ -710,7 +726,7 @@ unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
where where
A: RefUnwindSafe + UnwindSafe, A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>, D: Aggregate<A, T>,
T: ToSql, T: SqlFnOutput,
{ {
// Within the xFinal callback, it is customary to set N=0 in calls to // Within the xFinal callback, it is customary to set N=0 in calls to
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
@ -744,17 +760,7 @@ where
} }
Ok(r) => r, Ok(r) => r,
}; };
let t = t.as_ref().map(|(t, sub_type)| (ToSql::to_sql(t), sub_type)); sql_result(ctx, t);
match t {
Ok((Ok(ref value), sub_type)) => {
set_result(ctx, value);
if let Some(sub_type) = sub_type {
ffi::sqlite3_result_subtype(ctx, *sub_type);
}
}
Ok((Err(err), _)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
}
} }
#[cfg(feature = "window")] #[cfg(feature = "window")]
@ -806,13 +812,13 @@ mod test {
#[cfg(feature = "window")] #[cfg(feature = "window")]
use crate::functions::WindowAggregate; use crate::functions::WindowAggregate;
use crate::functions::{Aggregate, Context, FunctionFlags, SubType}; use crate::functions::{Aggregate, Context, FunctionFlags};
use crate::{Connection, Error, Result}; use crate::{Connection, Error, Result};
fn half(ctx: &Context<'_>) -> Result<(c_double, SubType)> { fn half(ctx: &Context<'_>) -> Result<c_double> {
assert_eq!(ctx.len(), 1, "called with unexpected number of arguments"); assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
let value = ctx.get::<c_double>(0)?; let value = ctx.get::<c_double>(0)?;
Ok((value / 2f64, None)) Ok(value / 2f64)
} }
#[test] #[test]
@ -851,7 +857,7 @@ mod test {
// This implementation of a regexp scalar function uses SQLite's auxiliary data // This implementation of a regexp scalar function uses SQLite's auxiliary data
// (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
// expression multiple times within one query. // expression multiple times within one query.
fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<(bool, SubType)> { fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<bool> {
assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>; type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
let regexp: std::sync::Arc<Regex> = ctx let regexp: std::sync::Arc<Regex> = ctx
@ -868,7 +874,7 @@ mod test {
regexp.is_match(text) regexp.is_match(text)
}; };
Ok((is_match, None)) Ok(is_match)
} }
#[test] #[test]
@ -915,7 +921,7 @@ mod test {
ret.push_str(&s); ret.push_str(&s);
} }
Ok((ret, None)) Ok(ret)
}, },
)?; )?;
@ -940,7 +946,7 @@ mod test {
assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType)); assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100); assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
} }
Ok((true, None)) Ok(true)
})?; })?;
let res: bool = let res: bool =
@ -963,12 +969,8 @@ mod test {
Ok(()) Ok(())
} }
fn finalize( fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
&self, Ok(sum)
_: &mut Context<'_>,
sum: Option<i64>,
) -> Result<(Option<i64>, SubType)> {
Ok((sum, None))
} }
} }
@ -982,8 +984,8 @@ mod test {
Ok(()) Ok(())
} }
fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<(i64, SubType)> { fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
Ok((sum.unwrap_or(0), None)) Ok(sum.unwrap_or(0))
} }
} }
@ -1041,8 +1043,8 @@ mod test {
Ok(()) Ok(())
} }
fn value(&self, sum: Option<&mut i64>) -> Result<(Option<i64>, SubType)> { fn value(&self, sum: Option<&mut i64>) -> Result<Option<i64>> {
Ok((sum.copied(), None)) Ok(sum.copied())
} }
} }

View File

@ -1801,7 +1801,7 @@ mod test {
functions::FunctionFlags::default(), functions::FunctionFlags::default(),
move |_| { move |_| {
interrupt_handle.interrupt(); interrupt_handle.interrupt();
Ok((0, None)) Ok(0)
}, },
)?; )?;