Add support to function flags (#622)

Breaking changes
This commit is contained in:
gwenn 2020-01-26 18:11:11 +01:00 committed by GitHub
parent b61c570cdd
commit 5565d2e058
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 144 additions and 79 deletions

View File

@ -10,41 +10,47 @@
//! //!
//! ```rust //! ```rust
//! use regex::Regex; //! use regex::Regex;
//! use rusqlite::functions::FunctionFlags;
//! use rusqlite::{Connection, Error, Result, NO_PARAMS}; //! use rusqlite::{Connection, Error, Result, NO_PARAMS};
//! //!
//! fn add_regexp_function(db: &Connection) -> Result<()> { //! fn add_regexp_function(db: &Connection) -> Result<()> {
//! db.create_scalar_function("regexp", 2, true, move |ctx| { //! db.create_scalar_function(
//! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments"); //! "regexp",
//! 2,
//! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
//! move |ctx| {
//! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
//! //!
//! let saved_re: Option<&Regex> = ctx.get_aux(0)?; //! let saved_re: Option<&Regex> = ctx.get_aux(0)?;
//! let new_re = match saved_re { //! let new_re = match saved_re {
//! None => { //! None => {
//! let s = ctx.get::<String>(0)?; //! let s = ctx.get::<String>(0)?;
//! match Regex::new(&s) { //! match Regex::new(&s) {
//! Ok(r) => Some(r), //! Ok(r) => Some(r),
//! Err(err) => return Err(Error::UserFunctionError(Box::new(err))), //! Err(err) => return Err(Error::UserFunctionError(Box::new(err))),
//! }
//! } //! }
//! Some(_) => None,
//! };
//!
//! let is_match = {
//! let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap());
//!
//! let text = ctx
//! .get_raw(1)
//! .as_str()
//! .map_err(|e| Error::UserFunctionError(e.into()))?;
//!
//! re.is_match(text)
//! };
//!
//! if let Some(re) = new_re {
//! ctx.set_aux(0, re);
//! } //! }
//! Some(_) => None,
//! };
//! //!
//! let is_match = { //! Ok(is_match)
//! let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); //! },
//! //! )
//! let text = ctx
//! .get_raw(1)
//! .as_str()
//! .map_err(|e| Error::UserFunctionError(e.into()))?;
//!
//! re.is_match(text)
//! };
//!
//! if let Some(re) = new_re {
//! ctx.set_aux(0, re);
//! }
//!
//! Ok(is_match)
//! })
//! } //! }
//! //!
//! fn main() -> Result<()> { //! fn main() -> Result<()> {
@ -241,6 +247,28 @@ where
fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>; fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
} }
bitflags::bitflags! {
#[doc = "Function Flags."]
#[doc = "See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html) for details."]
#[repr(C)]
pub struct FunctionFlags: ::std::os::raw::c_int {
const SQLITE_UTF8 = ffi::SQLITE_UTF8;
const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE;
const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE;
const SQLITE_UTF16 = ffi::SQLITE_UTF16;
const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC;
const SQLITE_DIRECTONLY = 0x0000_0008_0000; // 3.30.0
const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0
const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0
}
}
impl Default for FunctionFlags {
fn default() -> FunctionFlags {
FunctionFlags::SQLITE_UTF8
}
}
impl Connection { impl Connection {
/// Attach a user-defined scalar function to this database connection. /// Attach a user-defined scalar function to this database connection.
/// ///
@ -256,11 +284,17 @@ impl Connection {
/// ///
/// ```rust /// ```rust
/// # use rusqlite::{Connection, Result, NO_PARAMS}; /// # use rusqlite::{Connection, Result, NO_PARAMS};
/// # use rusqlite::functions::FunctionFlags;
/// fn scalar_function_example(db: Connection) -> Result<()> { /// fn scalar_function_example(db: Connection) -> Result<()> {
/// db.create_scalar_function("halve", 1, true, |ctx| { /// db.create_scalar_function(
/// let value = ctx.get::<f64>(0)?; /// "halve",
/// Ok(value / 2f64) /// 1,
/// })?; /// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
/// |ctx| {
/// let value = ctx.get::<f64>(0)?;
/// Ok(value / 2f64)
/// },
/// )?;
/// ///
/// let six_halved: f64 = db.query_row("SELECT halve(6)", NO_PARAMS, |r| r.get(0))?; /// let six_halved: f64 = db.query_row("SELECT halve(6)", NO_PARAMS, |r| r.get(0))?;
/// assert_eq!(six_halved, 3f64); /// assert_eq!(six_halved, 3f64);
@ -275,7 +309,7 @@ impl Connection {
&self, &self,
fn_name: &str, fn_name: &str,
n_arg: c_int, n_arg: c_int,
deterministic: bool, flags: FunctionFlags,
x_func: F, x_func: F,
) -> Result<()> ) -> Result<()>
where where
@ -284,7 +318,7 @@ impl Connection {
{ {
self.db self.db
.borrow_mut() .borrow_mut()
.create_scalar_function(fn_name, n_arg, deterministic, x_func) .create_scalar_function(fn_name, n_arg, flags, x_func)
} }
/// Attach a user-defined aggregate function to this database connection. /// Attach a user-defined aggregate function to this database connection.
@ -296,7 +330,7 @@ impl Connection {
&self, &self,
fn_name: &str, fn_name: &str,
n_arg: c_int, n_arg: c_int,
deterministic: bool, flags: FunctionFlags,
aggr: D, aggr: D,
) -> Result<()> ) -> Result<()>
where where
@ -306,7 +340,7 @@ impl Connection {
{ {
self.db self.db
.borrow_mut() .borrow_mut()
.create_aggregate_function(fn_name, n_arg, deterministic, aggr) .create_aggregate_function(fn_name, n_arg, flags, aggr)
} }
#[cfg(feature = "window")] #[cfg(feature = "window")]
@ -314,7 +348,7 @@ impl Connection {
&self, &self,
fn_name: &str, fn_name: &str,
n_arg: c_int, n_arg: c_int,
deterministic: bool, flags: FunctionFlags,
aggr: W, aggr: W,
) -> Result<()> ) -> Result<()>
where where
@ -324,7 +358,7 @@ impl Connection {
{ {
self.db self.db
.borrow_mut() .borrow_mut()
.create_window_function(fn_name, n_arg, deterministic, aggr) .create_window_function(fn_name, n_arg, flags, aggr)
} }
/// Removes a user-defined function from this database connection. /// Removes a user-defined function from this database connection.
@ -345,7 +379,7 @@ impl InnerConnection {
&mut self, &mut self,
fn_name: &str, fn_name: &str,
n_arg: c_int, n_arg: c_int,
deterministic: bool, flags: FunctionFlags,
x_func: F, x_func: F,
) -> Result<()> ) -> Result<()>
where where
@ -387,16 +421,12 @@ impl InnerConnection {
let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
let c_name = str_to_cstring(fn_name)?; let c_name = str_to_cstring(fn_name)?;
let mut flags = ffi::SQLITE_UTF8;
if deterministic {
flags |= ffi::SQLITE_DETERMINISTIC;
}
let r = unsafe { let r = unsafe {
ffi::sqlite3_create_function_v2( ffi::sqlite3_create_function_v2(
self.db(), self.db(),
c_name.as_ptr(), c_name.as_ptr(),
n_arg, n_arg,
flags, flags.bits(),
boxed_f as *mut c_void, boxed_f as *mut c_void,
Some(call_boxed_closure::<F, T>), Some(call_boxed_closure::<F, T>),
None, None,
@ -411,7 +441,7 @@ impl InnerConnection {
&mut self, &mut self,
fn_name: &str, fn_name: &str,
n_arg: c_int, n_arg: c_int,
deterministic: bool, flags: FunctionFlags,
aggr: D, aggr: D,
) -> Result<()> ) -> Result<()>
where where
@ -421,16 +451,12 @@ impl InnerConnection {
{ {
let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
let c_name = str_to_cstring(fn_name)?; let c_name = str_to_cstring(fn_name)?;
let mut flags = ffi::SQLITE_UTF8;
if deterministic {
flags |= ffi::SQLITE_DETERMINISTIC;
}
let r = unsafe { let r = unsafe {
ffi::sqlite3_create_function_v2( ffi::sqlite3_create_function_v2(
self.db(), self.db(),
c_name.as_ptr(), c_name.as_ptr(),
n_arg, n_arg,
flags, flags.bits(),
boxed_aggr as *mut c_void, boxed_aggr as *mut c_void,
None, None,
Some(call_boxed_step::<A, D, T>), Some(call_boxed_step::<A, D, T>),
@ -446,7 +472,7 @@ impl InnerConnection {
&mut self, &mut self,
fn_name: &str, fn_name: &str,
n_arg: c_int, n_arg: c_int,
deterministic: bool, flags: FunctionFlags,
aggr: W, aggr: W,
) -> Result<()> ) -> Result<()>
where where
@ -456,16 +482,12 @@ impl InnerConnection {
{ {
let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr)); let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
let c_name = str_to_cstring(fn_name)?; let c_name = str_to_cstring(fn_name)?;
let mut flags = ffi::SQLITE_UTF8;
if deterministic {
flags |= ffi::SQLITE_DETERMINISTIC;
}
let r = unsafe { let r = unsafe {
ffi::sqlite3_create_window_function( ffi::sqlite3_create_window_function(
self.db(), self.db(),
c_name.as_ptr(), c_name.as_ptr(),
n_arg, n_arg,
flags, flags.bits(),
boxed_aggr as *mut c_void, boxed_aggr as *mut c_void,
Some(call_boxed_step::<A, W, T>), Some(call_boxed_step::<A, W, T>),
Some(call_boxed_final::<A, W, T>), Some(call_boxed_final::<A, W, T>),
@ -687,7 +709,7 @@ mod test {
#[cfg(feature = "window")] #[cfg(feature = "window")]
use crate::functions::WindowAggregate; use crate::functions::WindowAggregate;
use crate::functions::{Aggregate, Context}; use crate::functions::{Aggregate, Context, FunctionFlags};
use crate::{Connection, Error, Result, NO_PARAMS}; use crate::{Connection, Error, Result, NO_PARAMS};
fn half(ctx: &Context<'_>) -> Result<c_double> { fn half(ctx: &Context<'_>) -> Result<c_double> {
@ -699,7 +721,13 @@ mod test {
#[test] #[test]
fn test_function_half() { fn test_function_half() {
let db = Connection::open_in_memory().unwrap(); let db = Connection::open_in_memory().unwrap();
db.create_scalar_function("half", 1, true, half).unwrap(); db.create_scalar_function(
"half",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
half,
)
.unwrap();
let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0)); let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
assert!((3f64 - result.unwrap()).abs() < EPSILON); assert!((3f64 - result.unwrap()).abs() < EPSILON);
@ -708,7 +736,13 @@ mod test {
#[test] #[test]
fn test_remove_function() { fn test_remove_function() {
let db = Connection::open_in_memory().unwrap(); let db = Connection::open_in_memory().unwrap();
db.create_scalar_function("half", 1, true, half).unwrap(); db.create_scalar_function(
"half",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
half,
)
.unwrap();
let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0)); let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
assert!((3f64 - result.unwrap()).abs() < EPSILON); assert!((3f64 - result.unwrap()).abs() < EPSILON);
@ -765,8 +799,13 @@ mod test {
END;", END;",
) )
.unwrap(); .unwrap();
db.create_scalar_function("regexp", 2, true, regexp_with_auxilliary) db.create_scalar_function(
.unwrap(); "regexp",
2,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
regexp_with_auxilliary,
)
.unwrap();
let result: Result<bool> = let result: Result<bool> =
db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", NO_PARAMS, |r| { db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", NO_PARAMS, |r| {
@ -787,16 +826,21 @@ mod test {
#[test] #[test]
fn test_varargs_function() { fn test_varargs_function() {
let db = Connection::open_in_memory().unwrap(); let db = Connection::open_in_memory().unwrap();
db.create_scalar_function("my_concat", -1, true, |ctx| { db.create_scalar_function(
let mut ret = String::new(); "my_concat",
-1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
|ctx| {
let mut ret = String::new();
for idx in 0..ctx.len() { for idx in 0..ctx.len() {
let s = ctx.get::<String>(idx)?; let s = ctx.get::<String>(idx)?;
ret.push_str(&s); ret.push_str(&s);
} }
Ok(ret) Ok(ret)
}) },
)
.unwrap(); .unwrap();
for &(expected, query) in &[ for &(expected, query) in &[
@ -812,7 +856,7 @@ mod test {
#[test] #[test]
fn test_get_aux_type_checking() { fn test_get_aux_type_checking() {
let db = Connection::open_in_memory().unwrap(); let db = Connection::open_in_memory().unwrap();
db.create_scalar_function("example", 2, false, |ctx| { db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
if !ctx.get::<bool>(1)? { if !ctx.get::<bool>(1)? {
ctx.set_aux::<i64>(0, 100); ctx.set_aux::<i64>(0, 100);
} else { } else {
@ -870,8 +914,13 @@ mod test {
#[test] #[test]
fn test_sum() { fn test_sum() {
let db = Connection::open_in_memory().unwrap(); let db = Connection::open_in_memory().unwrap();
db.create_aggregate_function("my_sum", 1, true, Sum) db.create_aggregate_function(
.unwrap(); "my_sum",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
Sum,
)
.unwrap();
// sum should return NULL when given no columns (contrast with count below) // sum should return NULL when given no columns (contrast with count below)
let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
@ -893,8 +942,13 @@ mod test {
#[test] #[test]
fn test_count() { fn test_count() {
let db = Connection::open_in_memory().unwrap(); let db = Connection::open_in_memory().unwrap();
db.create_aggregate_function("my_count", -1, true, Count) db.create_aggregate_function(
.unwrap(); "my_count",
-1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
Count,
)
.unwrap();
// count should return 0 when given no columns (contrast with sum above) // count should return 0 when given no columns (contrast with sum above)
let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
@ -924,7 +978,13 @@ mod test {
use fallible_iterator::FallibleIterator; use fallible_iterator::FallibleIterator;
let db = Connection::open_in_memory().unwrap(); let db = Connection::open_in_memory().unwrap();
db.create_window_function("sumint", 1, true, Sum).unwrap(); db.create_window_function(
"sumint",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
Sum,
)
.unwrap();
db.execute_batch( db.execute_batch(
"CREATE TABLE t3(x, y); "CREATE TABLE t3(x, y);
INSERT INTO t3 VALUES('a', 4), INSERT INTO t3 VALUES('a', 4),

View File

@ -1370,10 +1370,15 @@ mod test {
let interrupt_handle = db.get_interrupt_handle(); let interrupt_handle = db.get_interrupt_handle();
db.create_scalar_function("interrupt", 0, false, move |_| { db.create_scalar_function(
interrupt_handle.interrupt(); "interrupt",
Ok(0) 0,
}) crate::functions::FunctionFlags::default(),
move |_| {
interrupt_handle.interrupt();
Ok(0)
},
)
.unwrap(); .unwrap();
let mut stmt = db let mut stmt = db