diff --git a/src/functions.rs b/src/functions.rs index 7a30255..f949559 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -10,41 +10,47 @@ //! //! ```rust //! use regex::Regex; +//! use rusqlite::functions::FunctionFlags; //! use rusqlite::{Connection, Error, Result, NO_PARAMS}; //! //! fn add_regexp_function(db: &Connection) -> Result<()> { -//! db.create_scalar_function("regexp", 2, true, move |ctx| { -//! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); +//! db.create_scalar_function( +//! "regexp", +//! 2, +//! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, +//! move |ctx| { +//! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); //! -//! let saved_re: Option<&Regex> = ctx.get_aux(0)?; -//! let new_re = match saved_re { -//! None => { -//! let s = ctx.get::(0)?; -//! match Regex::new(&s) { -//! Ok(r) => Some(r), -//! Err(err) => return Err(Error::UserFunctionError(Box::new(err))), +//! let saved_re: Option<&Regex> = ctx.get_aux(0)?; +//! let new_re = match saved_re { +//! None => { +//! let s = ctx.get::(0)?; +//! match Regex::new(&s) { +//! Ok(r) => Some(r), +//! Err(err) => return Err(Error::UserFunctionError(Box::new(err))), +//! } //! } +//! Some(_) => None, +//! }; +//! +//! let is_match = { +//! let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); +//! +//! let text = ctx +//! .get_raw(1) +//! .as_str() +//! .map_err(|e| Error::UserFunctionError(e.into()))?; +//! +//! re.is_match(text) +//! }; +//! +//! if let Some(re) = new_re { +//! ctx.set_aux(0, re); //! } -//! Some(_) => None, -//! }; //! -//! let is_match = { -//! let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); -//! -//! let text = ctx -//! .get_raw(1) -//! .as_str() -//! .map_err(|e| Error::UserFunctionError(e.into()))?; -//! -//! re.is_match(text) -//! }; -//! -//! if let Some(re) = new_re { -//! ctx.set_aux(0, re); -//! } -//! -//! Ok(is_match) -//! }) +//! Ok(is_match) +//! }, +//! ) //! } //! //! fn main() -> Result<()> { @@ -241,6 +247,28 @@ where fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>; } +bitflags::bitflags! { + #[doc = "Function Flags."] + #[doc = "See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html) for details."] + #[repr(C)] + pub struct FunctionFlags: ::std::os::raw::c_int { + const SQLITE_UTF8 = ffi::SQLITE_UTF8; + const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE; + const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE; + const SQLITE_UTF16 = ffi::SQLITE_UTF16; + const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC; + const SQLITE_DIRECTONLY = 0x0000_0008_0000; // 3.30.0 + const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0 + const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0 + } +} + +impl Default for FunctionFlags { + fn default() -> FunctionFlags { + FunctionFlags::SQLITE_UTF8 + } +} + impl Connection { /// Attach a user-defined scalar function to this database connection. /// @@ -256,11 +284,17 @@ impl Connection { /// /// ```rust /// # use rusqlite::{Connection, Result, NO_PARAMS}; + /// # use rusqlite::functions::FunctionFlags; /// fn scalar_function_example(db: Connection) -> Result<()> { - /// db.create_scalar_function("halve", 1, true, |ctx| { - /// let value = ctx.get::(0)?; - /// Ok(value / 2f64) - /// })?; + /// db.create_scalar_function( + /// "halve", + /// 1, + /// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + /// |ctx| { + /// let value = ctx.get::(0)?; + /// Ok(value / 2f64) + /// }, + /// )?; /// /// let six_halved: f64 = db.query_row("SELECT halve(6)", NO_PARAMS, |r| r.get(0))?; /// assert_eq!(six_halved, 3f64); @@ -275,7 +309,7 @@ impl Connection { &self, fn_name: &str, n_arg: c_int, - deterministic: bool, + flags: FunctionFlags, x_func: F, ) -> Result<()> where @@ -284,7 +318,7 @@ impl Connection { { self.db .borrow_mut() - .create_scalar_function(fn_name, n_arg, deterministic, x_func) + .create_scalar_function(fn_name, n_arg, flags, x_func) } /// Attach a user-defined aggregate function to this database connection. @@ -296,7 +330,7 @@ impl Connection { &self, fn_name: &str, n_arg: c_int, - deterministic: bool, + flags: FunctionFlags, aggr: D, ) -> Result<()> where @@ -306,7 +340,7 @@ impl Connection { { self.db .borrow_mut() - .create_aggregate_function(fn_name, n_arg, deterministic, aggr) + .create_aggregate_function(fn_name, n_arg, flags, aggr) } #[cfg(feature = "window")] @@ -314,7 +348,7 @@ impl Connection { &self, fn_name: &str, n_arg: c_int, - deterministic: bool, + flags: FunctionFlags, aggr: W, ) -> Result<()> where @@ -324,7 +358,7 @@ impl Connection { { self.db .borrow_mut() - .create_window_function(fn_name, n_arg, deterministic, aggr) + .create_window_function(fn_name, n_arg, flags, aggr) } /// Removes a user-defined function from this database connection. @@ -345,7 +379,7 @@ impl InnerConnection { &mut self, fn_name: &str, n_arg: c_int, - deterministic: bool, + flags: FunctionFlags, x_func: F, ) -> Result<()> where @@ -387,16 +421,12 @@ impl InnerConnection { let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); let c_name = str_to_cstring(fn_name)?; - let mut flags = ffi::SQLITE_UTF8; - if deterministic { - flags |= ffi::SQLITE_DETERMINISTIC; - } let r = unsafe { ffi::sqlite3_create_function_v2( self.db(), c_name.as_ptr(), n_arg, - flags, + flags.bits(), boxed_f as *mut c_void, Some(call_boxed_closure::), None, @@ -411,7 +441,7 @@ impl InnerConnection { &mut self, fn_name: &str, n_arg: c_int, - deterministic: bool, + flags: FunctionFlags, aggr: D, ) -> Result<()> where @@ -421,16 +451,12 @@ impl InnerConnection { { let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); let c_name = str_to_cstring(fn_name)?; - let mut flags = ffi::SQLITE_UTF8; - if deterministic { - flags |= ffi::SQLITE_DETERMINISTIC; - } let r = unsafe { ffi::sqlite3_create_function_v2( self.db(), c_name.as_ptr(), n_arg, - flags, + flags.bits(), boxed_aggr as *mut c_void, None, Some(call_boxed_step::), @@ -446,7 +472,7 @@ impl InnerConnection { &mut self, fn_name: &str, n_arg: c_int, - deterministic: bool, + flags: FunctionFlags, aggr: W, ) -> Result<()> where @@ -456,16 +482,12 @@ impl InnerConnection { { let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr)); let c_name = str_to_cstring(fn_name)?; - let mut flags = ffi::SQLITE_UTF8; - if deterministic { - flags |= ffi::SQLITE_DETERMINISTIC; - } let r = unsafe { ffi::sqlite3_create_window_function( self.db(), c_name.as_ptr(), n_arg, - flags, + flags.bits(), boxed_aggr as *mut c_void, Some(call_boxed_step::), Some(call_boxed_final::), @@ -687,7 +709,7 @@ mod test { #[cfg(feature = "window")] use crate::functions::WindowAggregate; - use crate::functions::{Aggregate, Context}; + use crate::functions::{Aggregate, Context, FunctionFlags}; use crate::{Connection, Error, Result, NO_PARAMS}; fn half(ctx: &Context<'_>) -> Result { @@ -699,7 +721,13 @@ mod test { #[test] fn test_function_half() { let db = Connection::open_in_memory().unwrap(); - db.create_scalar_function("half", 1, true, half).unwrap(); + db.create_scalar_function( + "half", + 1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + half, + ) + .unwrap(); let result: Result = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0)); assert!((3f64 - result.unwrap()).abs() < EPSILON); @@ -708,7 +736,13 @@ mod test { #[test] fn test_remove_function() { let db = Connection::open_in_memory().unwrap(); - db.create_scalar_function("half", 1, true, half).unwrap(); + db.create_scalar_function( + "half", + 1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + half, + ) + .unwrap(); let result: Result = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0)); assert!((3f64 - result.unwrap()).abs() < EPSILON); @@ -765,8 +799,13 @@ mod test { END;", ) .unwrap(); - db.create_scalar_function("regexp", 2, true, regexp_with_auxilliary) - .unwrap(); + db.create_scalar_function( + "regexp", + 2, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + regexp_with_auxilliary, + ) + .unwrap(); let result: Result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", NO_PARAMS, |r| { @@ -787,16 +826,21 @@ mod test { #[test] fn test_varargs_function() { let db = Connection::open_in_memory().unwrap(); - db.create_scalar_function("my_concat", -1, true, |ctx| { - let mut ret = String::new(); + db.create_scalar_function( + "my_concat", + -1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + |ctx| { + let mut ret = String::new(); - for idx in 0..ctx.len() { - let s = ctx.get::(idx)?; - ret.push_str(&s); - } + for idx in 0..ctx.len() { + let s = ctx.get::(idx)?; + ret.push_str(&s); + } - Ok(ret) - }) + Ok(ret) + }, + ) .unwrap(); for &(expected, query) in &[ @@ -812,7 +856,7 @@ mod test { #[test] fn test_get_aux_type_checking() { let db = Connection::open_in_memory().unwrap(); - db.create_scalar_function("example", 2, false, |ctx| { + db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| { if !ctx.get::(1)? { ctx.set_aux::(0, 100); } else { @@ -870,8 +914,13 @@ mod test { #[test] fn test_sum() { let db = Connection::open_in_memory().unwrap(); - db.create_aggregate_function("my_sum", 1, true, Sum) - .unwrap(); + db.create_aggregate_function( + "my_sum", + 1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + Sum, + ) + .unwrap(); // sum should return NULL when given no columns (contrast with count below) let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; @@ -893,8 +942,13 @@ mod test { #[test] fn test_count() { let db = Connection::open_in_memory().unwrap(); - db.create_aggregate_function("my_count", -1, true, Count) - .unwrap(); + db.create_aggregate_function( + "my_count", + -1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + Count, + ) + .unwrap(); // count should return 0 when given no columns (contrast with sum above) let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; @@ -924,7 +978,13 @@ mod test { use fallible_iterator::FallibleIterator; let db = Connection::open_in_memory().unwrap(); - db.create_window_function("sumint", 1, true, Sum).unwrap(); + db.create_window_function( + "sumint", + 1, + FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC, + Sum, + ) + .unwrap(); db.execute_batch( "CREATE TABLE t3(x, y); INSERT INTO t3 VALUES('a', 4), diff --git a/src/lib.rs b/src/lib.rs index f51b5f4..20db262 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1370,10 +1370,15 @@ mod test { let interrupt_handle = db.get_interrupt_handle(); - db.create_scalar_function("interrupt", 0, false, move |_| { - interrupt_handle.interrupt(); - Ok(0) - }) + db.create_scalar_function( + "interrupt", + 0, + crate::functions::FunctionFlags::default(), + move |_| { + interrupt_handle.interrupt(); + Ok(0) + }, + ) .unwrap(); let mut stmt = db