2020-04-06 13:15:27 +08:00
//! `feature = "functions"` Create or redefine SQL functions.
2015-12-12 05:27:39 +08:00
//!
//! # Example
//!
2018-08-17 00:29:46 +08:00
//! Adding a `regexp` function to a connection in which compiled regular
//! expressions are cached with SQLite's
//! [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html)
//! interface (via `Context::get_or_create_aux`), so the same pattern is not
//! recompiled multiple times within one query. The unit tests for this module
//! contain a similar implementation.
2015-12-12 05:27:39 +08:00
//!
//! ```rust
2018-08-17 00:29:46 +08:00
//! use regex::Regex;
2020-01-27 01:11:11 +08:00
//! use rusqlite::functions::FunctionFlags;
2020-11-03 17:32:46 +08:00
//! use rusqlite::{Connection, Error, Result};
2020-04-13 10:41:01 +08:00
//! use std::sync::Arc;
//! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
2015-12-12 05:27:39 +08:00
//!
2015-12-13 03:06:03 +08:00
//! fn add_regexp_function(db: &Connection) -> Result<()> {
2020-01-27 01:11:11 +08:00
//! db.create_scalar_function(
//! "regexp",
//! 2,
//! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
//! move |ctx| {
//! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
2020-10-29 04:12:29 +08:00
//! let regexp: Arc<Regex> = ctx.get_or_create_aux(0, |vr| -> Result<_, BoxError> {
//! Ok(Regex::new(vr.as_str()?)?)
//! })?;
2020-01-27 01:11:11 +08:00
//! let is_match = {
//! let text = ctx
//! .get_raw(1)
//! .as_str()
//! .map_err(|e| Error::UserFunctionError(e.into()))?;
2019-06-02 15:04:47 +08:00
//!
2020-04-13 10:41:01 +08:00
//! regexp.is_match(text)
2020-01-27 01:11:11 +08:00
//! };
2015-12-12 05:27:39 +08:00
//!
2020-01-27 01:11:11 +08:00
//! Ok(is_match)
//! },
//! )
2015-12-12 05:27:39 +08:00
//! }
//!
2019-02-09 14:16:05 +08:00
//! fn main() -> Result<()> {
//! let db = Connection::open_in_memory()?;
//! add_regexp_function(&db)?;
2015-12-12 05:27:39 +08:00
//!
2020-11-04 10:41:51 +08:00
//! let is_match: bool =
//! db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", [], |row| {
//! row.get(0)
//! })?;
2015-12-12 05:27:39 +08:00
//!
//! assert!(is_match);
2019-02-09 14:16:05 +08:00
//! Ok(())
2015-12-12 05:27:39 +08:00
//! }
//! ```
2020-04-13 10:41:01 +08:00
use std ::any ::Any ;
2020-12-18 20:54:38 +08:00
use std ::marker ::PhantomData ;
use std ::ops ::Deref ;
2018-08-11 19:37:56 +08:00
use std ::os ::raw ::{ c_int , c_void } ;
2018-12-16 16:40:14 +08:00
use std ::panic ::{ catch_unwind , RefUnwindSafe , UnwindSafe } ;
2015-12-12 04:47:52 +08:00
use std ::ptr ;
2015-12-12 03:46:28 +08:00
use std ::slice ;
2020-04-13 10:41:01 +08:00
use std ::sync ::Arc ;
2015-08-09 15:52:53 +08:00
2018-10-31 03:11:35 +08:00
use crate ::ffi ;
use crate ::ffi ::sqlite3_context ;
use crate ::ffi ::sqlite3_value ;
2015-08-09 15:52:53 +08:00
2018-10-31 03:11:35 +08:00
use crate ::context ::set_result ;
use crate ::types ::{ FromSql , FromSqlError , ToSql , ValueRef } ;
2015-08-09 15:52:53 +08:00
2018-10-31 03:11:35 +08:00
use crate ::{ str_to_cstring , Connection , Error , InnerConnection , Result } ;
2015-08-09 15:52:53 +08:00
2018-06-10 18:16:54 +08:00
unsafe fn report_error ( ctx : * mut sqlite3_context , err : & Error ) {
2018-08-17 00:29:46 +08:00
// Extended constraint error codes were added in SQLite 3.7.16. We don't have
// an explicit feature check for that, and this doesn't really warrant one.
// We'll use the extended code if we're on the bundled version (since it's
// at least 3.17.0) and the normal constraint error code if not.
2020-01-15 00:11:36 +08:00
#[ cfg(feature = " modern_sqlite " ) ]
2018-06-10 18:16:54 +08:00
fn constraint_error_code ( ) -> i32 {
ffi ::SQLITE_CONSTRAINT_FUNCTION
}
2020-01-15 00:11:36 +08:00
#[ cfg(not(feature = " modern_sqlite " )) ]
2018-06-10 18:16:54 +08:00
fn constraint_error_code ( ) -> i32 {
ffi ::SQLITE_CONSTRAINT
}
match * err {
Error ::SqliteFailure ( ref err , ref s ) = > {
ffi ::sqlite3_result_error_code ( ctx , err . extended_code ) ;
if let Some ( Ok ( cstr ) ) = s . as_ref ( ) . map ( | s | str_to_cstring ( s ) ) {
ffi ::sqlite3_result_error ( ctx , cstr . as_ptr ( ) , - 1 ) ;
}
}
_ = > {
ffi ::sqlite3_result_error_code ( ctx , constraint_error_code ( ) ) ;
2020-01-26 23:57:58 +08:00
if let Ok ( cstr ) = str_to_cstring ( & err . to_string ( ) ) {
2018-06-10 18:16:54 +08:00
ffi ::sqlite3_result_error ( ctx , cstr . as_ptr ( ) , - 1 ) ;
}
}
}
}
2015-12-12 02:54:08 +08:00
/// Destructor callback handed to SQLite: takes back ownership of a `Box<T>`
/// that was turned into a raw `*mut c_void` with `Box::into_raw`, and drops
/// the `T` it holds.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    let owned: Box<T> = Box::from_raw(p.cast::<T>());
    drop(owned);
}
2020-04-06 13:15:27 +08:00
/// `feature = "functions"` Context is a wrapper for the SQLite function
/// evaluation context.
2015-12-12 03:46:28 +08:00
pub struct Context < ' a > {
2015-12-12 04:08:40 +08:00
ctx : * mut sqlite3_context ,
2015-12-12 03:46:28 +08:00
args : & ' a [ * mut sqlite3_value ] ,
2015-12-12 02:54:08 +08:00
}
2019-02-03 18:02:38 +08:00
impl Context < '_ > {
2015-12-12 05:27:39 +08:00
/// Returns the number of arguments to the function.
2020-11-04 11:10:23 +08:00
#[ inline ]
2015-12-12 03:46:28 +08:00
pub fn len ( & self ) -> usize {
self . args . len ( )
}
2018-08-17 00:29:46 +08:00
2016-02-14 23:11:59 +08:00
/// Returns `true` when there is no argument.
2020-11-04 11:10:23 +08:00
#[ inline ]
2016-02-14 23:11:59 +08:00
pub fn is_empty ( & self ) -> bool {
self . args . is_empty ( )
}
2015-12-12 03:46:28 +08:00
2015-12-12 05:27:39 +08:00
/// Returns the `idx`th argument as a `T`.
///
/// # Failure
///
2021-04-03 17:03:50 +08:00
/// Will panic if `idx` is greater than or equal to
/// [`self.len()`](Context::len).
2015-12-12 05:27:39 +08:00
///
2018-08-17 00:29:46 +08:00
/// Will return Err if the underlying SQLite type cannot be converted to a
/// `T`.
2016-05-25 08:08:12 +08:00
pub fn get < T : FromSql > ( & self , idx : usize ) -> Result < T > {
2015-12-12 03:46:28 +08:00
let arg = self . args [ idx ] ;
2016-05-25 09:34:18 +08:00
let value = unsafe { ValueRef ::from_value ( arg ) } ;
2016-05-25 08:08:12 +08:00
FromSql ::column_result ( value ) . map_err ( | err | match err {
2018-08-11 18:48:21 +08:00
FromSqlError ::InvalidType = > {
2016-06-03 03:03:25 +08:00
Error ::InvalidFunctionParameterType ( idx , value . data_type ( ) )
}
2018-08-11 18:48:21 +08:00
FromSqlError ::OutOfRange ( i ) = > Error ::IntegralValueOutOfRange ( idx , i ) ,
FromSqlError ::Other ( err ) = > {
2016-06-03 03:03:25 +08:00
Error ::FromSqlConversionFailure ( idx , value . data_type ( ) , err )
}
2018-12-16 15:19:54 +08:00
#[ cfg(feature = " i128_blob " ) ]
FromSqlError ::InvalidI128Size ( _ ) = > {
Error ::FromSqlConversionFailure ( idx , value . data_type ( ) , Box ::new ( err ) )
}
2019-04-26 22:15:07 +08:00
#[ cfg(feature = " uuid " ) ]
FromSqlError ::InvalidUuidSize ( _ ) = > {
Error ::FromSqlConversionFailure ( idx , value . data_type ( ) , Box ::new ( err ) )
}
2018-08-11 18:48:21 +08:00
} )
2015-12-12 03:46:28 +08:00
}
2018-10-19 02:59:30 +08:00
/// Returns the `idx`th argument as a `ValueRef`.
///
/// # Failure
///
2021-04-03 17:03:50 +08:00
/// Will panic if `idx` is greater than or equal to
/// [`self.len()`](Context::len).
2020-11-04 11:10:23 +08:00
#[ inline ]
2019-02-03 18:02:38 +08:00
pub fn get_raw ( & self , idx : usize ) -> ValueRef < '_ > {
2018-10-19 02:59:30 +08:00
let arg = self . args [ idx ] ;
unsafe { ValueRef ::from_value ( arg ) }
}
2021-05-02 19:46:04 +08:00
/// Fetch or insert the auxiliary data associated with a particular
2020-04-13 10:46:17 +08:00
/// parameter. This is intended to be an easier-to-use way of fetching it
2021-04-03 17:03:50 +08:00
/// compared to calling [`get_aux`](Context::get_aux) and
/// [`set_aux`](Context::set_aux) separately.
2020-04-13 10:46:17 +08:00
///
2020-11-07 19:32:41 +08:00
/// See `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
2020-04-13 10:46:17 +08:00
/// this feature, or the unit tests of this module for an example.
2020-04-13 10:41:01 +08:00
pub fn get_or_create_aux < T , E , F > ( & self , arg : c_int , func : F ) -> Result < Arc < T > >
where
T : Send + Sync + 'static ,
E : Into < Box < dyn std ::error ::Error + Send + Sync + 'static > > ,
F : FnOnce ( ValueRef < '_ > ) -> Result < T , E > ,
{
if let Some ( v ) = self . get_aux ( arg ) ? {
Ok ( v )
} else {
let vr = self . get_raw ( arg as usize ) ;
self . set_aux (
arg ,
func ( vr ) . map_err ( | e | Error ::UserFunctionError ( e . into ( ) ) ) ? ,
)
}
}
2021-05-02 19:46:04 +08:00
/// Sets the auxiliary data associated with a particular parameter. See
2020-11-07 19:32:41 +08:00
/// `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
2015-12-12 05:27:39 +08:00
/// this feature, or the unit tests of this module for an example.
2020-04-13 10:41:01 +08:00
pub fn set_aux < T : Send + Sync + 'static > ( & self , arg : c_int , value : T ) -> Result < Arc < T > > {
let orig : Arc < T > = Arc ::new ( value ) ;
let inner : AuxInner = orig . clone ( ) ;
let outer = Box ::new ( inner ) ;
let raw : * mut AuxInner = Box ::into_raw ( outer ) ;
2015-12-12 02:54:08 +08:00
unsafe {
2018-08-11 18:48:21 +08:00
ffi ::sqlite3_set_auxdata (
self . ctx ,
arg ,
2020-04-13 10:41:01 +08:00
raw as * mut _ ,
Some ( free_boxed_value ::< AuxInner > ) ,
2018-08-11 18:48:21 +08:00
)
2015-12-12 02:54:08 +08:00
} ;
2020-04-13 10:41:01 +08:00
Ok ( orig )
2015-12-12 02:54:08 +08:00
}
2021-05-02 19:46:04 +08:00
/// Gets the auxiliary data that was associated with a given parameter via
2021-04-03 17:03:50 +08:00
/// [`set_aux`](Context::set_aux). Returns `Ok(None)` if no data has been
/// associated, and Ok(Some(v)) if it has. Returns an error if the
/// requested type does not match.
2020-04-13 10:41:01 +08:00
pub fn get_aux < T : Send + Sync + 'static > ( & self , arg : c_int ) -> Result < Option < Arc < T > > > {
let p = unsafe { ffi ::sqlite3_get_auxdata ( self . ctx , arg ) as * const AuxInner } ;
2018-08-11 18:48:21 +08:00
if p . is_null ( ) {
2019-01-25 16:43:50 +08:00
Ok ( None )
2018-08-11 18:48:21 +08:00
} else {
2020-04-13 10:41:01 +08:00
let v : AuxInner = AuxInner ::clone ( unsafe { & * p } ) ;
v . downcast ::< T > ( )
. map ( Some )
. map_err ( | _ | Error ::GetAuxWrongType )
2018-08-11 18:48:21 +08:00
}
2015-12-12 02:54:08 +08:00
}
2020-12-18 20:24:03 +08:00
2020-12-19 06:11:23 +08:00
/// Get the db connection handle via [sqlite3_context_db_handle](https://www.sqlite.org/c3ref/context_db_handle.html)
///
/// # Safety
2020-12-18 20:24:03 +08:00
///
/// This function is marked unsafe because there is a potential for other
2020-12-19 06:11:23 +08:00
/// references to the connection to be sent across threads, [see this comment](https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213).
2020-12-18 20:24:03 +08:00
pub unsafe fn get_connection ( & self ) -> Result < ConnectionRef < '_ > > {
let handle = ffi ::sqlite3_context_db_handle ( self . ctx ) ;
Ok ( ConnectionRef {
conn : Connection ::from_handle ( handle ) ? ,
2020-12-18 20:54:38 +08:00
phantom : PhantomData ,
2020-12-18 20:24:03 +08:00
} )
}
}
/// A reference to a connection handle with a lifetime bound to something.
pub struct ConnectionRef < ' ctx > {
2020-12-18 20:54:38 +08:00
// comes from Connection::from_handle(sqlite3_context_db_handle(...))
2020-12-18 20:24:03 +08:00
// and is non-owning
conn : Connection ,
phantom : PhantomData < & ' ctx Context < ' ctx > > ,
2015-12-12 02:54:08 +08:00
}
2020-12-18 20:24:03 +08:00
impl Deref for ConnectionRef < '_ > {
type Target = Connection ;
#[ inline ]
fn deref ( & self ) -> & Connection {
& self . conn
}
}
2020-04-13 10:41:01 +08:00
type AuxInner = Arc < dyn Any + Send + Sync + 'static > ;
2020-04-13 02:17:56 +08:00
2020-04-06 13:15:27 +08:00
/// `feature = "functions"` Aggregate is the callback interface for user-defined
/// aggregate function.
2015-12-21 02:27:28 +08:00
///
2018-08-17 00:29:46 +08:00
/// `A` is the type of the aggregation context and `T` is the type of the final
/// result. Implementations should be stateless.
2016-05-04 03:00:59 +08:00
pub trait Aggregate < A , T >
2018-08-11 18:48:21 +08:00
where
2018-12-16 16:40:14 +08:00
A : RefUnwindSafe + UnwindSafe ,
2018-08-11 18:48:21 +08:00
T : ToSql ,
2016-05-04 03:00:59 +08:00
{
2018-08-17 00:29:46 +08:00
/// Initializes the aggregation context. Will be called prior to the first
2021-04-03 17:03:50 +08:00
/// call to [`step()`](Aggregate::step) to set up the context for an
/// invocation of the function. (Note: `init()` will not be called if
/// there are no rows.)
2020-12-19 19:12:30 +08:00
fn init ( & self , _ : & mut Context < '_ > ) -> Result < A > ;
2016-01-08 01:33:32 +08:00
2018-08-17 00:29:46 +08:00
/// "step" function called once for each row in an aggregate group. May be
/// called 0 times if there are no rows.
2018-12-08 04:57:04 +08:00
fn step ( & self , _ : & mut Context < '_ > , _ : & mut A ) -> Result < ( ) > ;
2016-01-08 01:33:32 +08:00
2018-08-17 00:29:46 +08:00
/// Computes and returns the final result. Will be called exactly once for
2021-04-03 17:03:50 +08:00
/// each invocation of the function. If [`step()`](Aggregate::step) was
/// called at least once, will be given `Some(A)` (the same `A` as was
/// created by [`init`](Aggregate::init) and given to
/// [`step`](Aggregate::step)); if [`step()`](Aggregate::step) was not
/// called (because the function is running against 0 rows), will be
/// given `None`.
2020-12-18 20:10:35 +08:00
///
/// The passed context will have no arguments.
fn finalize ( & self , _ : & mut Context < '_ > , _ : Option < A > ) -> Result < T > ;
2015-12-16 03:54:23 +08:00
}
2020-04-06 13:15:27 +08:00
/// `feature = "window"` WindowAggregate is the callback interface for
/// user-defined aggregate window functions.
///
/// It extends [`Aggregate`] with the two extra callbacks a window function
/// needs.
#[cfg(feature = "window")]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
    A: RefUnwindSafe + UnwindSafe,
    T: ToSql,
{
    /// Returns the current value of the aggregate. Unlike xFinal, the
    /// implementation should not delete any context: it only borrows the
    /// state, and receives `None` if no state has been created yet.
    fn value(&self, acc: Option<&A>) -> Result<T>;

    /// Removes a row, whose arguments are in `ctx`, from the current window.
    fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
}
2020-01-27 01:11:11 +08:00
bitflags::bitflags! {
    /// Function Flags.
    /// See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html)
    /// and [Function Flags](https://sqlite.org/c3ref/c_deterministic.html) for details.
    ///
    /// Flags are combined with `|`; the combined bits are handed to SQLite as
    /// a plain `c_int` when a function is registered.
    #[repr(C)]
    pub struct FunctionFlags: ::std::os::raw::c_int {
        /// Specifies UTF-8 as the text encoding this SQL function prefers for its parameters.
        const SQLITE_UTF8 = ffi::SQLITE_UTF8;
        /// Specifies UTF-16 using little-endian byte order as the text encoding this SQL function prefers for its parameters.
        const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE;
        /// Specifies UTF-16 using big-endian byte order as the text encoding this SQL function prefers for its parameters.
        const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE;
        /// Specifies UTF-16 using native byte order as the text encoding this SQL function prefers for its parameters.
        const SQLITE_UTF16 = ffi::SQLITE_UTF16;
        /// Means that the function always gives the same output when the input parameters are the same.
        const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC;
        /// Means that the function may only be invoked from top-level SQL.
        /// (SQLite 3.30.0+; spelled as a literal here rather than taken from `ffi`.)
        const SQLITE_DIRECTONLY = 0x0000_0008_0000; // 3.30.0
        /// Indicates to SQLite that a function may call `sqlite3_value_subtype()` to inspect the sub-types of its arguments.
        /// (SQLite 3.30.0+; spelled as a literal here rather than taken from `ffi`.)
        const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0
        /// Means that the function is unlikely to cause problems even if misused.
        /// (SQLite 3.31.0+; spelled as a literal here rather than taken from `ffi`.)
        const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0
    }
}
impl Default for FunctionFlags {
2020-11-04 11:10:23 +08:00
#[ inline ]
2020-01-27 01:11:11 +08:00
fn default ( ) -> FunctionFlags {
FunctionFlags ::SQLITE_UTF8
}
}
2015-12-13 02:50:12 +08:00
impl Connection {
2020-04-06 13:15:27 +08:00
/// `feature = "functions"` Attach a user-defined scalar function to
/// this database connection.
2015-12-12 05:27:39 +08:00
///
/// `fn_name` is the name the function will be accessible from SQL.
2018-08-17 00:29:46 +08:00
/// `n_arg` is the number of arguments to the function. Use `-1` for a
/// variable number. If the function always returns the same value
/// given the same input, `deterministic` should be `true`.
2015-12-12 05:27:39 +08:00
///
/// The function will remain available until the connection is closed or
2021-04-03 17:03:50 +08:00
/// until it is explicitly removed via
/// [`remove_function`](Connection::remove_function).
2015-12-12 05:27:39 +08:00
///
/// # Example
///
/// ```rust
2020-11-03 15:34:08 +08:00
/// # use rusqlite::{Connection, Result};
2020-01-27 01:11:11 +08:00
/// # use rusqlite::functions::FunctionFlags;
2015-12-13 03:06:03 +08:00
/// fn scalar_function_example(db: Connection) -> Result<()> {
2020-01-27 01:11:11 +08:00
/// db.create_scalar_function(
/// "halve",
/// 1,
/// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
/// |ctx| {
/// let value = ctx.get::<f64>(0)?;
/// Ok(value / 2f64)
/// },
/// )?;
2015-12-12 05:27:39 +08:00
///
2020-11-03 15:34:08 +08:00
/// let six_halved: f64 = db.query_row("SELECT halve(6)", [], |r| r.get(0))?;
2015-12-12 05:27:39 +08:00
/// assert_eq!(six_halved, 3f64);
/// Ok(())
/// }
/// ```
///
/// # Failure
///
/// Will return Err if the function could not be attached to the connection.
2020-11-04 11:10:23 +08:00
#[ inline ]
2020-10-29 02:20:05 +08:00
pub fn create_scalar_function < ' c , F , T > (
& ' c self ,
2018-08-11 18:48:21 +08:00
fn_name : & str ,
n_arg : c_int ,
2020-01-27 01:11:11 +08:00
flags : FunctionFlags ,
2018-08-11 18:48:21 +08:00
x_func : F ,
) -> Result < ( ) >
where
2020-10-29 02:20:05 +08:00
F : FnMut ( & Context < '_ > ) -> Result < T > + Send + UnwindSafe + ' c ,
2018-08-11 18:48:21 +08:00
T : ToSql ,
2015-12-12 01:01:05 +08:00
{
2017-04-08 01:43:24 +08:00
self . db
. borrow_mut ( )
2020-01-27 01:11:11 +08:00
. create_scalar_function ( fn_name , n_arg , flags , x_func )
2015-08-09 15:52:53 +08:00
}
2015-12-12 04:47:52 +08:00
2020-04-06 13:15:27 +08:00
/// `feature = "functions"` Attach a user-defined aggregate function to this
/// database connection.
2015-12-21 02:27:28 +08:00
///
/// # Failure
///
/// Will return Err if the function could not be attached to the connection.
2020-11-04 11:10:23 +08:00
#[ inline ]
2018-08-11 18:48:21 +08:00
pub fn create_aggregate_function < A , D , T > (
& self ,
fn_name : & str ,
n_arg : c_int ,
2020-01-27 01:11:11 +08:00
flags : FunctionFlags ,
2018-08-11 18:48:21 +08:00
aggr : D ,
) -> Result < ( ) >
where
2018-12-16 16:40:14 +08:00
A : RefUnwindSafe + UnwindSafe ,
2018-08-11 18:48:21 +08:00
D : Aggregate < A , T > ,
T : ToSql ,
2015-12-16 03:54:23 +08:00
{
self . db
. borrow_mut ( )
2020-01-27 01:11:11 +08:00
. create_aggregate_function ( fn_name , n_arg , flags , aggr )
2015-12-16 03:54:23 +08:00
}
2020-04-06 13:15:27 +08:00
/// `feature = "window"` Attach a user-defined aggregate window function to
/// this database connection.
///
2020-11-07 19:32:41 +08:00
/// See `https://sqlite.org/windowfunctions.html#udfwinfunc` for more
2020-04-06 13:15:27 +08:00
/// information.
2019-06-26 02:33:49 +08:00
#[ cfg(feature = " window " ) ]
2020-11-04 11:10:23 +08:00
#[ inline ]
2019-06-26 02:33:49 +08:00
pub fn create_window_function < A , W , T > (
& self ,
fn_name : & str ,
n_arg : c_int ,
2020-01-27 01:11:11 +08:00
flags : FunctionFlags ,
2019-06-26 02:33:49 +08:00
aggr : W ,
) -> Result < ( ) >
where
A : RefUnwindSafe + UnwindSafe ,
W : WindowAggregate < A , T > ,
T : ToSql ,
{
self . db
. borrow_mut ( )
2020-01-27 01:11:11 +08:00
. create_window_function ( fn_name , n_arg , flags , aggr )
2019-06-26 02:33:49 +08:00
}
2020-04-06 13:15:27 +08:00
/// `feature = "functions"` Removes a user-defined function from this
/// database connection.
2015-12-12 05:27:39 +08:00
///
/// `fn_name` and `n_arg` should match the name and number of arguments
2021-04-03 17:03:50 +08:00
/// given to [`create_scalar_function`](Connection::create_scalar_function)
/// or [`create_aggregate_function`](Connection::create_aggregate_function).
2015-12-12 05:27:39 +08:00
///
/// # Failure
///
/// Will return Err if the function could not be removed.
2020-11-04 11:10:23 +08:00
#[ inline ]
2015-12-13 03:06:03 +08:00
pub fn remove_function ( & self , fn_name : & str , n_arg : c_int ) -> Result < ( ) > {
2015-12-12 04:47:52 +08:00
self . db . borrow_mut ( ) . remove_function ( fn_name , n_arg )
}
2015-08-09 15:52:53 +08:00
}
2015-12-13 02:50:12 +08:00
impl InnerConnection {
2020-10-29 02:20:05 +08:00
fn create_scalar_function < ' c , F , T > (
& ' c mut self ,
2018-08-11 18:48:21 +08:00
fn_name : & str ,
n_arg : c_int ,
2020-01-27 01:11:11 +08:00
flags : FunctionFlags ,
2018-08-11 18:48:21 +08:00
x_func : F ,
) -> Result < ( ) >
where
2020-10-29 02:20:05 +08:00
F : FnMut ( & Context < '_ > ) -> Result < T > + Send + UnwindSafe + ' c ,
2018-08-11 18:48:21 +08:00
T : ToSql ,
2015-12-12 01:01:05 +08:00
{
2018-08-11 18:48:21 +08:00
unsafe extern " C " fn call_boxed_closure < F , T > (
ctx : * mut sqlite3_context ,
argc : c_int ,
argv : * mut * mut sqlite3_value ,
) where
2018-12-08 04:57:04 +08:00
F : FnMut ( & Context < '_ > ) -> Result < T > ,
2018-08-11 18:48:21 +08:00
T : ToSql ,
2015-12-12 01:01:05 +08:00
{
2018-12-16 16:40:14 +08:00
let r = catch_unwind ( | | {
let boxed_f : * mut F = ffi ::sqlite3_user_data ( ctx ) as * mut F ;
assert! ( ! boxed_f . is_null ( ) , " Internal error - null function pointer " ) ;
let ctx = Context {
ctx ,
args : slice ::from_raw_parts ( argv , argc as usize ) ,
} ;
( * boxed_f ) ( & ctx )
} ) ;
let t = match r {
Err ( _ ) = > {
report_error ( ctx , & Error ::UnwindingPanic ) ;
return ;
}
Ok ( r ) = > r ,
2015-12-13 21:40:51 +08:00
} ;
2016-05-26 12:06:53 +08:00
let t = t . as_ref ( ) . map ( | t | ToSql ::to_sql ( t ) ) ;
match t {
2018-12-16 16:40:14 +08:00
Ok ( Ok ( ref value ) ) = > set_result ( ctx , value ) ,
Ok ( Err ( err ) ) = > report_error ( ctx , & err ) ,
Err ( err ) = > report_error ( ctx , err ) ,
2015-12-12 01:01:05 +08:00
}
}
let boxed_f : * mut F = Box ::into_raw ( Box ::new ( x_func ) ) ;
2018-10-31 03:11:35 +08:00
let c_name = str_to_cstring ( fn_name ) ? ;
2015-08-09 15:52:53 +08:00
let r = unsafe {
2018-08-11 18:48:21 +08:00
ffi ::sqlite3_create_function_v2 (
self . db ( ) ,
c_name . as_ptr ( ) ,
n_arg ,
2020-01-27 01:11:11 +08:00
flags . bits ( ) ,
2018-08-11 18:48:21 +08:00
boxed_f as * mut c_void ,
Some ( call_boxed_closure ::< F , T > ) ,
None ,
None ,
Some ( free_boxed_value ::< F > ) ,
)
2015-08-09 15:52:53 +08:00
} ;
self . decode_result ( r )
}
2015-12-12 04:47:52 +08:00
2018-08-11 18:48:21 +08:00
fn create_aggregate_function < A , D , T > (
& mut self ,
fn_name : & str ,
n_arg : c_int ,
2020-01-27 01:11:11 +08:00
flags : FunctionFlags ,
2018-08-11 18:48:21 +08:00
aggr : D ,
) -> Result < ( ) >
where
2018-12-16 16:40:14 +08:00
A : RefUnwindSafe + UnwindSafe ,
2018-08-11 18:48:21 +08:00
D : Aggregate < A , T > ,
T : ToSql ,
2015-12-16 03:54:23 +08:00
{
let boxed_aggr : * mut D = Box ::into_raw ( Box ::new ( aggr ) ) ;
2018-10-31 03:11:35 +08:00
let c_name = str_to_cstring ( fn_name ) ? ;
2015-12-16 03:54:23 +08:00
let r = unsafe {
2018-08-11 18:48:21 +08:00
ffi ::sqlite3_create_function_v2 (
self . db ( ) ,
c_name . as_ptr ( ) ,
n_arg ,
2020-01-27 01:11:11 +08:00
flags . bits ( ) ,
2018-08-11 18:48:21 +08:00
boxed_aggr as * mut c_void ,
None ,
Some ( call_boxed_step ::< A , D , T > ) ,
Some ( call_boxed_final ::< A , D , T > ) ,
Some ( free_boxed_value ::< D > ) ,
)
2015-12-16 03:54:23 +08:00
} ;
self . decode_result ( r )
}
2019-06-26 02:33:49 +08:00
#[ cfg(feature = " window " ) ]
fn create_window_function < A , W , T > (
& mut self ,
fn_name : & str ,
n_arg : c_int ,
2020-01-27 01:11:11 +08:00
flags : FunctionFlags ,
2019-06-26 02:33:49 +08:00
aggr : W ,
) -> Result < ( ) >
where
A : RefUnwindSafe + UnwindSafe ,
W : WindowAggregate < A , T > ,
T : ToSql ,
{
let boxed_aggr : * mut W = Box ::into_raw ( Box ::new ( aggr ) ) ;
let c_name = str_to_cstring ( fn_name ) ? ;
let r = unsafe {
ffi ::sqlite3_create_window_function (
self . db ( ) ,
c_name . as_ptr ( ) ,
n_arg ,
2020-01-27 01:11:11 +08:00
flags . bits ( ) ,
2019-06-26 02:33:49 +08:00
boxed_aggr as * mut c_void ,
Some ( call_boxed_step ::< A , W , T > ) ,
Some ( call_boxed_final ::< A , W , T > ) ,
Some ( call_boxed_value ::< A , W , T > ) ,
Some ( call_boxed_inverse ::< A , W , T > ) ,
Some ( free_boxed_value ::< W > ) ,
)
} ;
self . decode_result ( r )
}
2015-12-13 03:06:03 +08:00
fn remove_function ( & mut self , fn_name : & str , n_arg : c_int ) -> Result < ( ) > {
2018-10-31 03:11:35 +08:00
let c_name = str_to_cstring ( fn_name ) ? ;
2015-12-12 04:47:52 +08:00
let r = unsafe {
2018-08-11 18:48:21 +08:00
ffi ::sqlite3_create_function_v2 (
self . db ( ) ,
c_name . as_ptr ( ) ,
n_arg ,
ffi ::SQLITE_UTF8 ,
ptr ::null_mut ( ) ,
None ,
None ,
None ,
None ,
)
2015-12-12 04:47:52 +08:00
} ;
self . decode_result ( r )
}
2015-08-09 15:52:53 +08:00
}
2019-06-26 02:33:49 +08:00
/// Returns SQLite's per-invocation aggregate memory for `ctx`, viewed as a
/// slot holding a `*mut A`. Yields `None` when SQLite hands back a null
/// pointer, which the step/inverse callbacks report as out-of-memory.
///
/// `bytes` is forwarded to `sqlite3_aggregate_context`. The step/inverse
/// callbacks request one pointer's worth, and the final/value callbacks pass
/// `0`.
unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
    let slot = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
    if slot.is_null() {
        None
    } else {
        Some(slot)
    }
}
unsafe extern " C " fn call_boxed_step < A , D , T > (
ctx : * mut sqlite3_context ,
argc : c_int ,
argv : * mut * mut sqlite3_value ,
) where
A : RefUnwindSafe + UnwindSafe ,
D : Aggregate < A , T > ,
T : ToSql ,
{
let pac = match aggregate_context ( ctx , ::std ::mem ::size_of ::< * mut A > ( ) ) {
Some ( pac ) = > pac ,
None = > {
ffi ::sqlite3_result_error_nomem ( ctx ) ;
return ;
}
} ;
let r = catch_unwind ( | | {
let boxed_aggr : * mut D = ffi ::sqlite3_user_data ( ctx ) as * mut D ;
assert! (
! boxed_aggr . is_null ( ) ,
" Internal error - null aggregate pointer "
) ;
let mut ctx = Context {
ctx ,
args : slice ::from_raw_parts ( argv , argc as usize ) ,
} ;
2020-12-18 18:36:35 +08:00
if ( * pac as * mut A ) . is_null ( ) {
2020-12-19 19:12:30 +08:00
* pac = Box ::into_raw ( Box ::new ( ( * boxed_aggr ) . init ( & mut ctx ) ? ) ) ;
2020-12-18 18:36:35 +08:00
}
2019-06-26 02:33:49 +08:00
( * boxed_aggr ) . step ( & mut ctx , & mut * * pac )
} ) ;
let r = match r {
Err ( _ ) = > {
report_error ( ctx , & Error ::UnwindingPanic ) ;
return ;
}
Ok ( r ) = > r ,
} ;
match r {
Ok ( _ ) = > { }
Err ( err ) = > report_error ( ctx , & err ) ,
} ;
}
/// xInverse trampoline: removes the current row from the window via
/// `WindowAggregate::inverse`. Panics and errors are reported through `ctx`.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    let slot = if let Some(slot) = aggregate_context(ctx, std::mem::size_of::<*mut A>()) {
        slot
    } else {
        ffi::sqlite3_result_error_nomem(ctx);
        return;
    };

    let outcome = catch_unwind(|| {
        let aggr = ffi::sqlite3_user_data(ctx) as *mut W;
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        let mut fn_ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        (*aggr).inverse(&mut fn_ctx, &mut **slot)
    });

    match outcome {
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
        Ok(Err(err)) => report_error(ctx, &err),
        Ok(Ok(())) => {}
    }
}
unsafe extern " C " fn call_boxed_final < A , D , T > ( ctx : * mut sqlite3_context )
where
A : RefUnwindSafe + UnwindSafe ,
D : Aggregate < A , T > ,
T : ToSql ,
{
// Within the xFinal callback, it is customary to set N=0 in calls to
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
let a : Option < A > = match aggregate_context ( ctx , 0 ) {
Some ( pac ) = > {
if ( * pac as * mut A ) . is_null ( ) {
None
} else {
let a = Box ::from_raw ( * pac ) ;
Some ( * a )
}
}
None = > None ,
} ;
let r = catch_unwind ( | | {
let boxed_aggr : * mut D = ffi ::sqlite3_user_data ( ctx ) as * mut D ;
assert! (
! boxed_aggr . is_null ( ) ,
" Internal error - null aggregate pointer "
) ;
2020-12-18 20:52:07 +08:00
let mut ctx = Context { ctx , args : & mut [ ] } ;
2020-12-18 20:10:35 +08:00
( * boxed_aggr ) . finalize ( & mut ctx , a )
2019-06-26 02:33:49 +08:00
} ) ;
let t = match r {
Err ( _ ) = > {
report_error ( ctx , & Error ::UnwindingPanic ) ;
return ;
}
Ok ( r ) = > r ,
} ;
let t = t . as_ref ( ) . map ( | t | ToSql ::to_sql ( t ) ) ;
match t {
Ok ( Ok ( ref value ) ) = > set_result ( ctx , value ) ,
Ok ( Err ( err ) ) = > report_error ( ctx , & err ) ,
Err ( err ) = > report_error ( ctx , err ) ,
}
}
/// xValue trampoline: lends the current aggregate state (if any) to
/// `WindowAggregate::value` without consuming it, and stores the result or
/// error in `ctx`.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: ToSql,
{
    // Within the xValue callback, it is customary to set N=0 in calls to
    // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
    let state: Option<&A> = match aggregate_context::<A>(ctx, 0) {
        Some(slot) if !(*slot).is_null() => Some(&**slot),
        _ => None,
    };

    let outcome = catch_unwind(|| {
        let aggr = ffi::sqlite3_user_data(ctx) as *mut W;
        assert!(!aggr.is_null(), "Internal error - null aggregate pointer");
        (*aggr).value(state)
    });
    let value = match outcome {
        Ok(value) => value,
        Err(_) => {
            report_error(ctx, &Error::UnwindingPanic);
            return;
        }
    };

    match value {
        Ok(ref out) => match ToSql::to_sql(out) {
            Ok(ref sql) => set_result(ctx, sql),
            Err(err) => report_error(ctx, &err),
        },
        Err(ref err) => report_error(ctx, err),
    }
}
2015-08-09 15:52:53 +08:00
#[cfg(test)]
mod test {
    use regex::Regex;
    use std::os::raw::c_double;

    #[cfg(feature = "window")]
    use crate::functions::WindowAggregate;
    use crate::functions::{Aggregate, Context, FunctionFlags};
    use crate::{Connection, Error, Result};

    /// Scalar function used by several tests: returns its single argument
    /// (read as a `c_double`) divided by two.
    fn half(ctx: &Context<'_>) -> Result<c_double> {
        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
        let value = ctx.get::<c_double>(0)?;
        Ok(value / 2f64)
    }

    #[test]
    fn test_function_half() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));

        assert!((3f64 - result?).abs() < f64::EPSILON);
        Ok(())
    }

    #[test]
    fn test_remove_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
        assert!((3f64 - result?).abs() < f64::EPSILON);

        // Once removed, calling the function must fail.
        db.remove_function("half", 1)?;
        let result: Result<f64> = db.query_row("SELECT half(6)", [], |r| r.get(0));
        assert!(result.is_err());
        Ok(())
    }

    // This implementation of a regexp scalar function uses SQLite's auxiliary data
    // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
    // expression multiple times within one query.
    fn regexp_with_auxilliary(ctx: &Context<'_>) -> Result<bool> {
        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        // Compiled pattern for argument 0, reused via aux data when SQLite keeps it.
        let regexp: std::sync::Arc<Regex> = ctx
            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
                Ok(Regex::new(vr.as_str()?)?)
            })?;

        let is_match = {
            let text = ctx
                .get_raw(1)
                .as_str()
                .map_err(|e| Error::UserFunctionError(e.into()))?;

            regexp.is_match(text)
        };

        Ok(is_match)
    }

    #[test]
    fn test_function_regexp_with_auxilliary() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            "BEGIN;
             CREATE TABLE foo (x string);
             INSERT INTO foo VALUES ('lisa');
             INSERT INTO foo VALUES ('lXsi');
             INSERT INTO foo VALUES ('lisX');
             END;",
        )?;
        db.create_scalar_function(
            "regexp",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            regexp_with_auxilliary,
        )?;

        let result: Result<bool> =
            db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", [], |r| r.get(0));

        assert!(result?);

        let result: Result<i64> = db.query_row(
            "SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
            [],
            |r| r.get(0),
        );

        assert_eq!(2, result?);
        Ok(())
    }

    #[test]
    fn test_varargs_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "my_concat",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            |ctx| {
                let mut ret = String::new();
                for idx in 0..ctx.len() {
                    let s = ctx.get::<String>(idx)?;
                    ret.push_str(&s);
                }

                Ok(ret)
            },
        )?;

        for &(expected, query) in &[
            ("", "SELECT my_concat()"),
            ("onetwo", "SELECT my_concat('one', 'two')"),
            ("abc", "SELECT my_concat('a', 'b', 'c')"),
        ] {
            let result: String = db.query_row(query, [], |r| r.get(0))?;
            assert_eq!(expected, result);
        }
        Ok(())
    }

    #[test]
    fn test_get_aux_type_checking() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
            if !ctx.get::<bool>(1)? {
                ctx.set_aux::<i64>(0, 100)?;
            } else {
                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
                assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
            }
            Ok(true)
        })?;

        // `query_row` steps the statement only once, so only the first row
        // (i = 0, which merely stores the aux data) would be evaluated and the
        // checks in the `else` branch above would never run. Step through
        // every row so the i = 1 invocation actually reads the aux data back.
        let mut stmt = db.prepare("SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)")?;
        let res: Vec<bool> = stmt.query_map([], |r| r.get(0))?.collect::<Result<_>>()?;
        // The real checks are the asserts inside the function; a panic there
        // surfaces as an error while iterating the rows above.
        assert_eq!(res, vec![true, true]);
        Ok(())
    }

    /// Test aggregate that sums its integer argument; yields NULL for no rows.
    struct Sum;
    /// Test aggregate that counts its invocations; yields 0 for no rows.
    struct Count;

    impl Aggregate<i64, Option<i64>> for Sum {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += ctx.get::<i64>(0)?;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
            Ok(sum)
        }
    }

    impl Aggregate<i64, i64> for Count {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += 1;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
            Ok(sum.unwrap_or(0))
        }
    }

    #[test]
    fn test_sum() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            "my_sum",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;

        // sum should return NULL when given no columns (contrast with count below)
        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: Option<i64> = db.query_row(no_result, [], |r| r.get(0))?;
        assert!(result.is_none());

        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.query_row(single_sum, [], |r| r.get(0))?;
        assert_eq!(4, result);

        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
                        2, 1)";
        let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
        assert_eq!((4, 2), result);
        Ok(())
    }

    #[test]
    fn test_count() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            "my_count",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Count,
        )?;

        // count should return 0 when given no columns (contrast with sum above)
        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: i64 = db.query_row(no_result, [], |r| r.get(0))?;
        assert_eq!(result, 0);

        let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.query_row(single_sum, [], |r| r.get(0))?;
        assert_eq!(2, result);
        Ok(())
    }

    #[cfg(feature = "window")]
    impl WindowAggregate<i64, Option<i64>> for Sum {
        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum -= ctx.get::<i64>(0)?;
            Ok(())
        }

        fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
            Ok(sum.copied())
        }
    }

    #[test]
    #[cfg(feature = "window")]
    fn test_window() -> Result<()> {
        use fallible_iterator::FallibleIterator;

        let db = Connection::open_in_memory()?;
        db.create_window_function(
            "sumint",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;
        db.execute_batch(
            "CREATE TABLE t3(x, y);
             INSERT INTO t3 VALUES('a', 4),
                     ('b', 5),
                     ('c', 3),
                     ('d', 8),
                     ('e', 1);",
        )?;

        let mut stmt = db.prepare(
            "SELECT x, sumint(y) OVER (
                   ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
                 ) AS sum_y
                 FROM t3 ORDER BY x;",
        )?;

        let results: Vec<(String, i64)> = stmt
            .query([])?
            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
            .collect()?;
        let expected = vec![
            ("a".to_owned(), 9),
            ("b".to_owned(), 12),
            ("c".to_owned(), 16),
            ("d".to_owned(), 12),
            ("e".to_owned(), 9),
        ];
        assert_eq!(expected, results);
        Ok(())
    }
}