From 0a454eed795020bebd4d8c420171798b921646e3 Mon Sep 17 00:00:00 2001 From: Gwenael Treguier Date: Sun, 9 Aug 2015 09:52:53 +0200 Subject: [PATCH 01/13] Add support to user defined scalar functions --- Cargo.toml | 1 + libsqlite3-sys/src/lib.rs | 3 + src/functions.rs | 250 ++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 4 files changed, 255 insertions(+) create mode 100644 src/functions.rs diff --git a/Cargo.toml b/Cargo.toml index 5d05ed5..3241faf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ name = "rusqlite" [features] load_extension = ["libsqlite3-sys/load_extension"] +functions = [] [dependencies] time = "~0.1.0" diff --git a/libsqlite3-sys/src/lib.rs b/libsqlite3-sys/src/lib.rs index 5d37276..3261a7d 100644 --- a/libsqlite3-sys/src/lib.rs +++ b/libsqlite3-sys/src/lib.rs @@ -92,3 +92,6 @@ pub fn code_to_str(code: c_int) -> &'static str { _ => "Unknown error code", } } + +pub const SQLITE_UTF8 : c_int = 1; +pub const SQLITE_DETERMINISTIC : c_int = 0x800; diff --git a/src/functions.rs b/src/functions.rs new file mode 100644 index 0000000..d96f859 --- /dev/null +++ b/src/functions.rs @@ -0,0 +1,250 @@ +//! Create or redefine SQL functions +use std::ffi::{CStr}; +use std::mem; +use std::ptr; +use std::str; +use libc::{c_int, c_double, c_char}; + +use ffi; +pub use ffi::sqlite3_context as sqlite3_context; +pub use ffi::sqlite3_value as sqlite3_value; +pub use ffi::sqlite3_value_type as sqlite3_value_type; +pub use ffi::sqlite3_value_numeric_type as sqlite3_value_numeric_type; + +use types::Null; + +use {SqliteResult, SqliteError, SqliteConnection, str_to_cstring, InnerSqliteConnection}; + +/// A trait for types that can be converted into the result of an SQL function. +pub trait ToResult { + unsafe fn result(&self, ctx: *mut sqlite3_context); +} + +macro_rules! raw_to_impl( + ($t:ty, $f:ident) => ( + impl ToResult for $t { + unsafe fn result(&self, ctx: *mut sqlite3_context) { + ffi::$f(ctx, *self) + } + } + ) +); + +raw_to_impl!(c_int, sqlite3_result_int); +raw_to_impl!(i64, sqlite3_result_int64); +raw_to_impl!(c_double, sqlite3_result_double); + + +impl<'a> ToResult for &'a str { + unsafe fn result(&self, ctx: *mut sqlite3_context) { + let length = self.len(); + if length > ::std::i32::MAX as usize { + ffi::sqlite3_result_error_toobig(ctx); + return + } + match str_to_cstring(self) { + Ok(c_str) => ffi::sqlite3_result_text(ctx, c_str.as_ptr(), length as c_int, + ffi::SQLITE_TRANSIENT()), + Err(_) => ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISUSE), // TODO sqlite3_result_error + } + } +} + +impl ToResult for String { + unsafe fn result(&self, ctx: *mut sqlite3_context) { + (&self[..]).result(ctx) + } +} + +impl<'a> ToResult for &'a [u8] { + unsafe fn result(&self, ctx: *mut sqlite3_context) { + if self.len() > ::std::i32::MAX as usize { + ffi::sqlite3_result_error_toobig(ctx); + return + } + ffi::sqlite3_result_blob( + ctx, mem::transmute(self.as_ptr()), self.len() as c_int, ffi::SQLITE_TRANSIENT()) + } +} + +impl ToResult for Vec { + unsafe fn result(&self, ctx: *mut sqlite3_context) { + (&self[..]).result(ctx) + } +} + +impl ToResult for Option { + unsafe fn result(&self, ctx: *mut sqlite3_context) { + match *self { + None => ffi::sqlite3_result_null(ctx), + Some(ref t) => t.result(ctx), + } + } +} + +impl ToResult for Null { + unsafe fn result(&self, ctx: *mut sqlite3_context) { + ffi::sqlite3_result_null(ctx) + } +} + + +// sqlite3_result_error_code, c_int +// sqlite3_result_error_nomem +// sqlite3_result_error_toobig +// sqlite3_result_error, *const c_char, c_int +// sqlite3_result_zeroblob +// sqlite3_result_value + +/// A trait for types that can be created from a SQLite function parameter value. +pub trait FromValue { + unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult; + + /// FromValue types can implement this method and use sqlite3_value_type to check that + /// the type reported by SQLite matches a type suitable for Self. This method is used + /// by `???` to confirm that the parameter contains a valid type before + /// attempting to retrieve the value. + unsafe fn parameter_has_valid_sqlite_type(_: *mut sqlite3_value) -> bool { + true + } +} + + +macro_rules! raw_from_impl( + ($t:ty, $f:ident, $c:expr) => ( + impl FromValue for $t { + unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult<$t> { + Ok(ffi::$f(v)) + } + + unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { + sqlite3_value_numeric_type(v) == $c + } + } + ) +); + +raw_from_impl!(c_int, sqlite3_value_int, ffi::SQLITE_INTEGER); +raw_from_impl!(i64, sqlite3_value_int64, ffi::SQLITE_INTEGER); + +impl FromValue for c_double { + unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult { + Ok(ffi::sqlite3_value_double(v)) + } + + unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { + sqlite3_value_numeric_type(v) == ffi::SQLITE_FLOAT || sqlite3_value_numeric_type(v) == ffi::SQLITE_INTEGER + } +} + +impl FromValue for String { + unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult { + let c_text = ffi::sqlite3_value_text(v); + if c_text.is_null() { + Ok("".to_string()) + } else { + let c_slice = CStr::from_ptr(c_text as *const c_char).to_bytes(); + let utf8_str = str::from_utf8(c_slice); + utf8_str + .map(|s| { s.to_string() }) + .map_err(|e| { SqliteError{code: 0, message: e.to_string()} }) + } + } + + unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { + sqlite3_value_type(v) == ffi::SQLITE_TEXT + } +} + +impl FromValue for Vec { + unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult> { + use std::slice::from_raw_parts; + let c_blob = ffi::sqlite3_value_blob(v); + let len = ffi::sqlite3_value_bytes(v); + + assert!(len >= 0, "unexpected negative return from sqlite3_value_bytes"); + let len = len as usize; + + Ok(from_raw_parts(mem::transmute(c_blob), len).to_vec()) + } + + unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { + sqlite3_value_type(v) == ffi::SQLITE_BLOB + } +} + +impl FromValue for Option { + unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult> { + if sqlite3_value_type(v) == ffi::SQLITE_NULL { + Ok(None) + } else { + FromValue::parameter_value(v).map(|t| Some(t)) + } + } + + unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { + sqlite3_value_type(v) == ffi::SQLITE_NULL || + T::parameter_has_valid_sqlite_type(v) + } +} + +// sqlite3_user_data +// sqlite3_get_auxdata +// sqlite3_set_auxdata + +pub type ScalarFunc = + Option; + +impl SqliteConnection { + // TODO pApp + pub fn create_scalar_function(&self, fn_name: &str, n_arg: c_int, deterministic: bool, x_func: ScalarFunc) -> SqliteResult<()> { + self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) + } +} + +impl InnerSqliteConnection { + pub fn create_scalar_function(&mut self, fn_name: &str, n_arg: c_int, deterministic: bool, x_func: ScalarFunc) -> SqliteResult<()> { + let c_name = try!(str_to_cstring(fn_name)); + let mut flags = ffi::SQLITE_UTF8; + if deterministic { + flags |= ffi::SQLITE_DETERMINISTIC; + } + let r = unsafe { + ffi::sqlite3_create_function_v2(self.db(), c_name.as_ptr(), n_arg, flags, ptr::null_mut(), x_func, None, None, None) + }; + self.decode_result(r) + } +} + +#[cfg(test)] +mod test { + use libc::{c_int, c_double}; + use SqliteConnection; + use ffi; + use ffi::sqlite3_context as sqlite3_context; + use ffi::sqlite3_value as sqlite3_value; + use functions::{FromValue,ToResult}; + + extern "C" fn half(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { + unsafe { + let arg = *argv.offset(0); + if c_double::parameter_has_valid_sqlite_type(arg) { + let value = c_double::parameter_value(arg).unwrap() / 2f64; + value.result(ctx); + } else { + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISMATCH); + } + } + } + + #[test] + fn test_half() { + let db = SqliteConnection::open_in_memory().unwrap(); + db.create_scalar_function("half", 1, true, Some(half)).unwrap(); + let result = db.query_row("SELECT half(6)", + &[], + |r| r.get::(0)); + + assert_eq!(3f64, result.unwrap()); + } +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index f98780b..b63df2a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -79,6 +79,7 @@ pub use transaction::{SqliteTransactionBehavior, pub mod types; mod transaction; #[cfg(feature = "load_extension")] mod load_extension_guard; +#[cfg(feature = "functions")] pub mod functions; /// A typedef of the result returned by many methods. pub type SqliteResult = Result; From b9ab3350eaea7a861b861e3e92e7fb6c0aaf2290 Mon Sep 17 00:00:00 2001 From: Gwenael Treguier Date: Sun, 9 Aug 2015 13:06:23 +0200 Subject: [PATCH 02/13] Add regexp() function implementation --- Cargo.toml | 1 + src/functions.rs | 106 +++++++++++++++++++++++++++++++++++++++++------ src/lib.rs | 1 + src/types.rs | 22 ++++++++++ 4 files changed, 117 insertions(+), 13 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3241faf..366a38e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,6 +23,7 @@ libc = "~0.1" [dev-dependencies] tempdir = "~0.3.4" +regex = "~0.1.41" [dependencies.libsqlite3-sys] path = "libsqlite3-sys" diff --git a/src/functions.rs b/src/functions.rs index d96f859..763c7e6 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -17,13 +17,13 @@ use {SqliteResult, SqliteError, SqliteConnection, str_to_cstring, InnerSqliteCon /// A trait for types that can be converted into the result of an SQL function. pub trait ToResult { - unsafe fn result(&self, ctx: *mut sqlite3_context); + unsafe fn set_result(&self, ctx: *mut sqlite3_context); } macro_rules! raw_to_impl( ($t:ty, $f:ident) => ( impl ToResult for $t { - unsafe fn result(&self, ctx: *mut sqlite3_context) { + unsafe fn set_result(&self, ctx: *mut sqlite3_context) { ffi::$f(ctx, *self) } } @@ -34,9 +34,18 @@ raw_to_impl!(c_int, sqlite3_result_int); raw_to_impl!(i64, sqlite3_result_int64); raw_to_impl!(c_double, sqlite3_result_double); +impl<'a> ToResult for bool { + unsafe fn set_result(&self, ctx: *mut sqlite3_context) { + match *self { + true => ffi::sqlite3_result_int(ctx, 1), + _ => ffi::sqlite3_result_int(ctx, 0), + } + } +} + impl<'a> ToResult for &'a str { - unsafe fn result(&self, ctx: *mut sqlite3_context) { + unsafe fn set_result(&self, ctx: *mut sqlite3_context) { let length = self.len(); if length > ::std::i32::MAX as usize { ffi::sqlite3_result_error_toobig(ctx); @@ -51,13 +60,13 @@ impl<'a> ToResult for &'a str { } impl ToResult for String { - unsafe fn result(&self, ctx: *mut sqlite3_context) { - (&self[..]).result(ctx) + unsafe fn set_result(&self, ctx: *mut sqlite3_context) { + (&self[..]).set_result(ctx) } } impl<'a> ToResult for &'a [u8] { - unsafe fn result(&self, ctx: *mut sqlite3_context) { + unsafe fn set_result(&self, ctx: *mut sqlite3_context) { if self.len() > ::std::i32::MAX as usize { ffi::sqlite3_result_error_toobig(ctx); return @@ -68,22 +77,22 @@ impl<'a> ToResult for &'a [u8] { } impl ToResult for Vec { - unsafe fn result(&self, ctx: *mut sqlite3_context) { - (&self[..]).result(ctx) + unsafe fn set_result(&self, ctx: *mut sqlite3_context) { + (&self[..]).set_result(ctx) } } impl ToResult for Option { - unsafe fn result(&self, ctx: *mut sqlite3_context) { + unsafe fn set_result(&self, ctx: *mut sqlite3_context) { match *self { None => ffi::sqlite3_result_null(ctx), - Some(ref t) => t.result(ctx), + Some(ref t) => t.set_result(ctx), } } } impl ToResult for Null { - unsafe fn result(&self, ctx: *mut sqlite3_context) { + unsafe fn set_result(&self, ctx: *mut sqlite3_context) { ffi::sqlite3_result_null(ctx) } } @@ -127,6 +136,19 @@ macro_rules! raw_from_impl( raw_from_impl!(c_int, sqlite3_value_int, ffi::SQLITE_INTEGER); raw_from_impl!(i64, sqlite3_value_int64, ffi::SQLITE_INTEGER); +impl FromValue for bool { + unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult { + match ffi::sqlite3_value_int(v) { + 0 => Ok(false), + _ => Ok(true), + } + } + + unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { + sqlite3_value_numeric_type(v) == ffi::SQLITE_INTEGER + } +} + impl FromValue for c_double { unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult { Ok(ffi::sqlite3_value_double(v)) @@ -218,7 +240,14 @@ impl InnerSqliteConnection { #[cfg(test)] mod test { - use libc::{c_int, c_double}; + extern crate regex; + + use std::boxed::Box; + use std::ffi::{CString}; + use std::mem; + use libc::{c_int, c_double, c_void}; + use self::regex::Regex; + use SqliteConnection; use ffi; use ffi::sqlite3_context as sqlite3_context; @@ -230,7 +259,7 @@ mod test { let arg = *argv.offset(0); if c_double::parameter_has_valid_sqlite_type(arg) { let value = c_double::parameter_value(arg).unwrap() / 2f64; - value.result(ctx); + value.set_result(ctx); } else { ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISMATCH); } @@ -247,4 +276,55 @@ mod test { assert_eq!(3f64, result.unwrap()); } + + extern "C" fn regexp_free(raw: *mut c_void) { + unsafe { + Box::from_raw(raw); + } + } + + extern "C" fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { + unsafe { + let mut re_ptr = ffi::sqlite3_get_auxdata(ctx, 0) as *const Regex; + let mut re_opt = None; + if re_ptr.is_null() { + let raw = String::parameter_value(*argv.offset(0)); + if raw.is_err() { + let msg = CString::new(format!("{}", raw.unwrap_err())).unwrap(); + ffi::sqlite3_result_error(ctx, msg.as_ptr(), -1); + return + } + let comp = Regex::new(raw.unwrap().as_ref()); + if comp.is_err() { + let msg = CString::new(format!("{}", comp.unwrap_err())).unwrap(); + ffi::sqlite3_result_error(ctx, msg.as_ptr(), -1); + return + } + let re = comp.unwrap(); + re_ptr = &re as *const Regex; + re_opt = Some(re); + } + + let text = String::parameter_value(*argv.offset(1)); + if text.is_ok() { + let text = text.unwrap(); + (*re_ptr).is_match(text.as_ref()).set_result(ctx); + } + + if re_opt.is_some() { + ffi::sqlite3_set_auxdata(ctx, 0, mem::transmute(Box::into_raw(Box::new(re_opt.unwrap()))), Some(regexp_free)); + } + } + } + + #[test] + fn test_regexp() { + let db = SqliteConnection::open_in_memory().unwrap(); + db.create_scalar_function("regexp", 2, true, Some(regexp)).unwrap(); + let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", + &[], + |r| r.get::(0)); + + assert_eq!(true, result.unwrap()); + } } \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index b63df2a..ce31617 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -50,6 +50,7 @@ //! } //! } //! ``` +#![cfg_attr(test, feature(box_raw))] extern crate libc; extern crate libsqlite3_sys as ffi; #[macro_use] extern crate bitflags; diff --git a/src/types.rs b/src/types.rs index 1d30b30..ea8c6ac 100644 --- a/src/types.rs +++ b/src/types.rs @@ -100,6 +100,15 @@ raw_to_impl!(c_int, sqlite3_bind_int); raw_to_impl!(i64, sqlite3_bind_int64); raw_to_impl!(c_double, sqlite3_bind_double); +impl ToSql for bool { + unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { + match *self { + true => ffi::sqlite3_bind_int(stmt, col, 1), + _ => ffi::sqlite3_bind_int(stmt, col, 0), + } + } +} + impl<'a> ToSql for &'a str { unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { match str_to_cstring(self) { @@ -188,6 +197,19 @@ raw_from_impl!(c_int, sqlite3_column_int, ffi::SQLITE_INTEGER); raw_from_impl!(i64, sqlite3_column_int64, ffi::SQLITE_INTEGER); raw_from_impl!(c_double, sqlite3_column_double, ffi::SQLITE_FLOAT); +impl FromSql for bool { + unsafe fn column_result(stmt: *mut sqlite3_stmt, col: c_int) -> SqliteResult { + match ffi::sqlite3_column_int(stmt, col) { + 0 => Ok(false), + _ => Ok(true), + } + } + + unsafe fn column_has_valid_sqlite_type(stmt: *mut sqlite3_stmt, col: c_int) -> bool { + sqlite3_column_type(stmt, col) == ffi::SQLITE_INTEGER + } +} + impl FromSql for String { unsafe fn column_result(stmt: *mut sqlite3_stmt, col: c_int) -> SqliteResult { let c_text = ffi::sqlite3_column_text(stmt, col); From 0c3575e845bd49c4d44da2ed9aa4631f7f964709 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Fri, 11 Dec 2015 11:40:53 -0500 Subject: [PATCH 03/13] Fix segfault in regexp user function test --- src/functions.rs | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index fa6220e..3f21889 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -278,16 +278,16 @@ mod test { } extern "C" fn regexp_free(raw: *mut c_void) { - unsafe { - Box::from_raw(raw); - } + let _: Box = unsafe { + Box::from_raw(mem::transmute(raw)) + }; } extern "C" fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { unsafe { let mut re_ptr = ffi::sqlite3_get_auxdata(ctx, 0) as *const Regex; - let mut re_opt = None; - if re_ptr.is_null() { + let need_re = re_ptr.is_null(); + if need_re { let raw = String::parameter_value(*argv.offset(0)); if raw.is_err() { let msg = CString::new(format!("{}", raw.unwrap_err())).unwrap(); @@ -300,9 +300,8 @@ mod test { ffi::sqlite3_result_error(ctx, msg.as_ptr(), -1); return } - let re = comp.unwrap(); - re_ptr = &re as *const Regex; - re_opt = Some(re); + let re = Box::new(comp.unwrap()); + re_ptr = Box::into_raw(re); } let text = String::parameter_value(*argv.offset(1)); @@ -311,8 +310,8 @@ mod test { (*re_ptr).is_match(text.as_ref()).set_result(ctx); } - if re_opt.is_some() { - ffi::sqlite3_set_auxdata(ctx, 0, mem::transmute(Box::into_raw(Box::new(re_opt.unwrap()))), Some(regexp_free)); + if need_re { + ffi::sqlite3_set_auxdata(ctx, 0, mem::transmute(re_ptr), Some(regexp_free)); } } } From aae431760ecf95742507653aa5718ea91ef90f6d Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Fri, 11 Dec 2015 11:41:40 -0500 Subject: [PATCH 04/13] rustfmt - no code changes --- src/functions.rs | 103 ++++++++++++++++++++++++++++++----------------- 1 file changed, 65 insertions(+), 38 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 3f21889..52ed7ca 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -1,15 +1,15 @@ //! Create or redefine SQL functions -use std::ffi::{CStr}; +use std::ffi::CStr; use std::mem; use std::ptr; use std::str; use libc::{c_int, c_double, c_char}; use ffi; -pub use ffi::sqlite3_context as sqlite3_context; -pub use ffi::sqlite3_value as sqlite3_value; -pub use ffi::sqlite3_value_type as sqlite3_value_type; -pub use ffi::sqlite3_value_numeric_type as sqlite3_value_numeric_type; +pub use ffi::sqlite3_context; +pub use ffi::sqlite3_value; +pub use ffi::sqlite3_value_type; +pub use ffi::sqlite3_value_numeric_type; use types::Null; @@ -49,12 +49,16 @@ impl<'a> ToResult for &'a str { let length = self.len(); if length > ::std::i32::MAX as usize { ffi::sqlite3_result_error_toobig(ctx); - return + return; } match str_to_cstring(self) { - Ok(c_str) => ffi::sqlite3_result_text(ctx, c_str.as_ptr(), length as c_int, - ffi::SQLITE_TRANSIENT()), - Err(_) => ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISUSE), // TODO sqlite3_result_error + Ok(c_str) => { + ffi::sqlite3_result_text(ctx, + c_str.as_ptr(), + length as c_int, + ffi::SQLITE_TRANSIENT()) + } + Err(_) => ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISUSE), // TODO sqlite3_result_error } } } @@ -69,10 +73,12 @@ impl<'a> ToResult for &'a [u8] { unsafe fn set_result(&self, ctx: *mut sqlite3_context) { if self.len() > ::std::i32::MAX as usize { ffi::sqlite3_result_error_toobig(ctx); - return + return; } - ffi::sqlite3_result_blob( - ctx, mem::transmute(self.as_ptr()), self.len() as c_int, ffi::SQLITE_TRANSIENT()) + ffi::sqlite3_result_blob(ctx, + mem::transmute(self.as_ptr()), + self.len() as c_int, + ffi::SQLITE_TRANSIENT()) } } @@ -155,7 +161,8 @@ impl FromValue for c_double { } unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { - sqlite3_value_numeric_type(v) == ffi::SQLITE_FLOAT || sqlite3_value_numeric_type(v) == ffi::SQLITE_INTEGER + sqlite3_value_numeric_type(v) == ffi::SQLITE_FLOAT || + sqlite3_value_numeric_type(v) == ffi::SQLITE_INTEGER } } @@ -167,9 +174,13 @@ impl FromValue for String { } else { let c_slice = CStr::from_ptr(c_text as *const c_char).to_bytes(); let utf8_str = str::from_utf8(c_slice); - utf8_str - .map(|s| { s.to_string() }) - .map_err(|e| { SqliteError{code: 0, message: e.to_string()} }) + utf8_str.map(|s| s.to_string()) + .map_err(|e| { + SqliteError { + code: 0, + message: e.to_string(), + } + }) } } @@ -184,7 +195,8 @@ impl FromValue for Vec { let c_blob = ffi::sqlite3_value_blob(v); let len = ffi::sqlite3_value_bytes(v); - assert!(len >= 0, "unexpected negative return from sqlite3_value_bytes"); + assert!(len >= 0, + "unexpected negative return from sqlite3_value_bytes"); let len = len as usize; Ok(from_raw_parts(mem::transmute(c_blob), len).to_vec()) @@ -205,8 +217,7 @@ impl FromValue for Option { } unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { - sqlite3_value_type(v) == ffi::SQLITE_NULL || - T::parameter_has_valid_sqlite_type(v) + sqlite3_value_type(v) == ffi::SQLITE_NULL || T::parameter_has_valid_sqlite_type(v) } } @@ -214,25 +225,45 @@ impl FromValue for Option { // sqlite3_get_auxdata // sqlite3_set_auxdata -pub type ScalarFunc = - Option; +pub type ScalarFunc = Option; impl SqliteConnection { // TODO pApp - pub fn create_scalar_function(&self, fn_name: &str, n_arg: c_int, deterministic: bool, x_func: ScalarFunc) -> SqliteResult<()> { + pub fn create_scalar_function(&self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + x_func: ScalarFunc) + -> SqliteResult<()> { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } } impl InnerSqliteConnection { - pub fn create_scalar_function(&mut self, fn_name: &str, n_arg: c_int, deterministic: bool, x_func: ScalarFunc) -> SqliteResult<()> { + pub fn create_scalar_function(&mut self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + x_func: ScalarFunc) + -> SqliteResult<()> { let c_name = try!(str_to_cstring(fn_name)); let mut flags = ffi::SQLITE_UTF8; if deterministic { flags |= ffi::SQLITE_DETERMINISTIC; } let r = unsafe { - ffi::sqlite3_create_function_v2(self.db(), c_name.as_ptr(), n_arg, flags, ptr::null_mut(), x_func, None, None, None) + ffi::sqlite3_create_function_v2(self.db(), + c_name.as_ptr(), + n_arg, + flags, + ptr::null_mut(), + x_func, + None, + None, + None) }; self.decode_result(r) } @@ -243,16 +274,16 @@ mod test { extern crate regex; use std::boxed::Box; - use std::ffi::{CString}; + use std::ffi::CString; use std::mem; use libc::{c_int, c_double, c_void}; use self::regex::Regex; use SqliteConnection; use ffi; - use ffi::sqlite3_context as sqlite3_context; - use ffi::sqlite3_value as sqlite3_value; - use functions::{FromValue,ToResult}; + use ffi::sqlite3_context; + use ffi::sqlite3_value; + use functions::{FromValue, ToResult}; extern "C" fn half(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { unsafe { @@ -270,17 +301,13 @@ mod test { fn test_half() { let db = SqliteConnection::open_in_memory().unwrap(); db.create_scalar_function("half", 1, true, Some(half)).unwrap(); - let result = db.query_row("SELECT half(6)", - &[], - |r| r.get::(0)); + let result = db.query_row("SELECT half(6)", &[], |r| r.get::(0)); assert_eq!(3f64, result.unwrap()); } extern "C" fn regexp_free(raw: *mut c_void) { - let _: Box = unsafe { - Box::from_raw(mem::transmute(raw)) - }; + let _: Box = unsafe { Box::from_raw(mem::transmute(raw)) }; } extern "C" fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { @@ -292,13 +319,13 @@ mod test { if raw.is_err() { let msg = CString::new(format!("{}", raw.unwrap_err())).unwrap(); ffi::sqlite3_result_error(ctx, msg.as_ptr(), -1); - return + return; } let comp = Regex::new(raw.unwrap().as_ref()); if comp.is_err() { let msg = CString::new(format!("{}", comp.unwrap_err())).unwrap(); ffi::sqlite3_result_error(ctx, msg.as_ptr(), -1); - return + return; } let re = Box::new(comp.unwrap()); re_ptr = Box::into_raw(re); @@ -321,8 +348,8 @@ mod test { let db = SqliteConnection::open_in_memory().unwrap(); db.create_scalar_function("regexp", 2, true, Some(regexp)).unwrap(); let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", - &[], - |r| r.get::(0)); + &[], + |r| r.get::(0)); assert_eq!(true, result.unwrap()); } From 29494f46f6cf7528caed429048f1a60afe8e7134 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Fri, 11 Dec 2015 12:01:05 -0500 Subject: [PATCH 05/13] Let create_scalar_function take an FnMut instead of a extern "C" fn. --- src/functions.rs | 75 ++++++++++++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 28 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 52ed7ca..bb386a6 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -1,9 +1,8 @@ //! Create or redefine SQL functions use std::ffi::CStr; use std::mem; -use std::ptr; use std::str; -use libc::{c_int, c_double, c_char}; +use libc::{c_int, c_double, c_char, c_void}; use ffi; pub use ffi::sqlite3_context; @@ -225,30 +224,50 @@ impl FromValue for Option { // sqlite3_get_auxdata // sqlite3_set_auxdata -pub type ScalarFunc = Option; +pub trait ScalarFunction: FnMut(*mut sqlite3_context, c_int, *mut *mut sqlite3_value) {} +impl ScalarFunction for F {} impl SqliteConnection { - // TODO pApp - pub fn create_scalar_function(&self, - fn_name: &str, - n_arg: c_int, - deterministic: bool, - x_func: ScalarFunc) - -> SqliteResult<()> { + pub fn create_scalar_function(&self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + x_func: F) + -> SqliteResult<()> + where F: ScalarFunction + { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } } impl InnerSqliteConnection { - pub fn create_scalar_function(&mut self, - fn_name: &str, - n_arg: c_int, - deterministic: bool, - x_func: ScalarFunc) - -> SqliteResult<()> { + pub fn create_scalar_function(&mut self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + x_func: F) + -> SqliteResult<()> + where F: ScalarFunction + { + extern "C" fn free_boxed_closure(p: *mut c_void) + where F: ScalarFunction + { + let _: Box = unsafe { Box::from_raw(mem::transmute(p)) }; + } + + extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value) + where F: ScalarFunction + { + unsafe { + let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx)); + assert!(!boxed_f.is_null(), "Internal error - null function pointer"); + (*boxed_f)(ctx, argc, argv); + } + } + + let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); let c_name = try!(str_to_cstring(fn_name)); let mut flags = ffi::SQLITE_UTF8; if deterministic { @@ -259,11 +278,11 @@ impl InnerSqliteConnection { c_name.as_ptr(), n_arg, flags, - ptr::null_mut(), - x_func, + mem::transmute(boxed_f), + Some(call_boxed_closure::), None, None, - None) + Some(free_boxed_closure::)) }; self.decode_result(r) } @@ -285,7 +304,7 @@ mod test { use ffi::sqlite3_value; use functions::{FromValue, ToResult}; - extern "C" fn half(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { + fn half(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { unsafe { let arg = *argv.offset(0); if c_double::parameter_has_valid_sqlite_type(arg) { @@ -298,9 +317,9 @@ mod test { } #[test] - fn test_half() { + fn test_function_half() { let db = SqliteConnection::open_in_memory().unwrap(); - db.create_scalar_function("half", 1, true, Some(half)).unwrap(); + db.create_scalar_function("half", 1, true, half).unwrap(); let result = db.query_row("SELECT half(6)", &[], |r| r.get::(0)); assert_eq!(3f64, result.unwrap()); @@ -310,7 +329,7 @@ mod test { let _: Box = unsafe { Box::from_raw(mem::transmute(raw)) }; } - extern "C" fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { + fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { unsafe { let mut re_ptr = ffi::sqlite3_get_auxdata(ctx, 0) as *const Regex; let need_re = re_ptr.is_null(); @@ -344,9 +363,9 @@ mod test { } #[test] - fn test_regexp() { + fn test_function_regexp() { let db = SqliteConnection::open_in_memory().unwrap(); - db.create_scalar_function("regexp", 2, true, Some(regexp)).unwrap(); + db.create_scalar_function("regexp", 2, true, regexp).unwrap(); let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", &[], |r| r.get::(0)); From 94d40c41c78038c2faedb18c5d748395e5e64038 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Fri, 11 Dec 2015 13:54:08 -0500 Subject: [PATCH 06/13] Introduce Context wrapper for user-defined functions. This commit adds get/set auxilliary data for arguments; more to come. --- src/functions.rs | 104 ++++++++++++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 38 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index bb386a6..4c024b2 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -224,8 +224,37 @@ impl FromValue for Option { // sqlite3_get_auxdata // sqlite3_set_auxdata -pub trait ScalarFunction: FnMut(*mut sqlite3_context, c_int, *mut *mut sqlite3_value) {} -impl ScalarFunction for F {} +unsafe extern "C" fn free_boxed_value(p: *mut c_void) { + let _: Box = Box::from_raw(mem::transmute(p)); +} + +pub struct Context { + pub ctx: *mut sqlite3_context, +} + +impl Context { + pub fn set_aux(&self, arg: c_int, value: T) { + let boxed = Box::into_raw(Box::new(value)); + unsafe { + ffi::sqlite3_set_auxdata(self.ctx, + arg, + mem::transmute(boxed), + Some(mem::transmute(free_boxed_value::))) + }; + } + + pub unsafe fn get_aux(&self, arg: c_int) -> Option<&T> { + let p = ffi::sqlite3_get_auxdata(self.ctx, arg) as *mut T; + if p.is_null() { + None + } else { + Some(&*p) + } + } +} + +pub trait ScalarFunction: FnMut(&Context, c_int, *mut *mut sqlite3_value) {} +impl ScalarFunction for F {} impl SqliteConnection { pub fn create_scalar_function(&self, @@ -249,21 +278,16 @@ impl InnerSqliteConnection { -> SqliteResult<()> where F: ScalarFunction { - extern "C" fn free_boxed_closure(p: *mut c_void) - where F: ScalarFunction - { - let _: Box = unsafe { Box::from_raw(mem::transmute(p)) }; - } - extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value) where F: ScalarFunction { + let ctx = Context { ctx: ctx }; unsafe { - let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx)); + let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx.ctx)); assert!(!boxed_f.is_null(), "Internal error - null function pointer"); - (*boxed_f)(ctx, argc, argv); + (*boxed_f)(&ctx, argc, argv); } } @@ -282,7 +306,7 @@ impl InnerSqliteConnection { Some(call_boxed_closure::), None, None, - Some(free_boxed_closure::)) + Some(mem::transmute(free_boxed_value::))) }; self.decode_result(r) } @@ -292,26 +316,23 @@ impl InnerSqliteConnection { mod test { extern crate regex; - use std::boxed::Box; use std::ffi::CString; - use std::mem; - use libc::{c_int, c_double, c_void}; + use libc::{c_int, c_double}; use self::regex::Regex; use SqliteConnection; use ffi; - use ffi::sqlite3_context; use ffi::sqlite3_value; - use functions::{FromValue, ToResult}; + use functions::{Context, FromValue, ToResult}; - fn half(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { + fn half(ctx: &Context, _: c_int, argv: *mut *mut sqlite3_value) { unsafe { let arg = *argv.offset(0); if c_double::parameter_has_valid_sqlite_type(arg) { let value = c_double::parameter_value(arg).unwrap() / 2f64; - value.set_result(ctx); + value.set_result(ctx.ctx); } else { - ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISMATCH); + ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_MISMATCH); } } } @@ -325,40 +346,47 @@ mod test { assert_eq!(3f64, result.unwrap()); } - extern "C" fn regexp_free(raw: *mut c_void) { - let _: Box = unsafe { Box::from_raw(mem::transmute(raw)) }; - } - - fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { - unsafe { - let mut re_ptr = ffi::sqlite3_get_auxdata(ctx, 0) as *const Regex; - let need_re = re_ptr.is_null(); - if need_re { + fn regexp(ctx: &Context, _: c_int, argv: *mut *mut sqlite3_value) { + let saved_re: Option<&Regex> = unsafe { ctx.get_aux(0) }; + let new_re = match saved_re { + None => unsafe { let raw = String::parameter_value(*argv.offset(0)); if raw.is_err() { let msg = CString::new(format!("{}", raw.unwrap_err())).unwrap(); - ffi::sqlite3_result_error(ctx, msg.as_ptr(), -1); + ffi::sqlite3_result_error(ctx.ctx, msg.as_ptr(), -1); return; } let comp = Regex::new(raw.unwrap().as_ref()); if comp.is_err() { let msg = CString::new(format!("{}", comp.unwrap_err())).unwrap(); - ffi::sqlite3_result_error(ctx, msg.as_ptr(), -1); + ffi::sqlite3_result_error(ctx.ctx, msg.as_ptr(), -1); return; } - let re = Box::new(comp.unwrap()); - re_ptr = Box::into_raw(re); - } + Some(comp.unwrap()) + }, + Some(_) => None, + }; - let text = String::parameter_value(*argv.offset(1)); + { + let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); + + let text = unsafe { String::parameter_value(*argv.offset(1)) }; if text.is_ok() { let text = text.unwrap(); - (*re_ptr).is_match(text.as_ref()).set_result(ctx); + unsafe { + re.is_match(text.as_ref()).set_result(ctx.ctx); + } + } else { + let msg = CString::new(format!("{}", text.unwrap_err())).unwrap(); + unsafe { + ffi::sqlite3_result_error(ctx.ctx, msg.as_ptr(), -1); + } + return; } + } - if need_re { - ffi::sqlite3_set_auxdata(ctx, 0, mem::transmute(re_ptr), Some(regexp_free)); - } + if let Some(re) = new_re { + ctx.set_aux(0, re); } } From 81ec7fe7cd12bf2bb28e07170984df0adcdfa47e Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Fri, 11 Dec 2015 14:46:28 -0500 Subject: [PATCH 07/13] Add `get` to function::Context. This allows user-defined functions to now only accept a `Context`, as it embeds the arguments inside itself. --- src/functions.rs | 93 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 62 insertions(+), 31 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 4c024b2..6481b8a 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -1,6 +1,7 @@ //! Create or redefine SQL functions use std::ffi::CStr; use std::mem; +use std::slice; use std::str; use libc::{c_int, c_double, c_char, c_void}; @@ -228,11 +229,30 @@ unsafe extern "C" fn free_boxed_value(p: *mut c_void) { let _: Box = Box::from_raw(mem::transmute(p)); } -pub struct Context { +pub struct Context<'a> { pub ctx: *mut sqlite3_context, + args: &'a [*mut sqlite3_value], } -impl Context { +impl<'a> Context<'a> { + pub fn len(&self) -> usize { + self.args.len() + } + + pub fn get(&self, idx: usize) -> SqliteResult { + let arg = self.args[idx]; + unsafe { + if T::parameter_has_valid_sqlite_type(arg) { + T::parameter_value(arg) + } else { + Err(SqliteError { + code: ffi::SQLITE_MISMATCH, + message: "Invalid value type".to_string(), + }) + } + } + } + pub fn set_aux(&self, arg: c_int, value: T) { let boxed = Box::into_raw(Box::new(value)); unsafe { @@ -253,9 +273,6 @@ impl Context { } } -pub trait ScalarFunction: FnMut(&Context, c_int, *mut *mut sqlite3_value) {} -impl ScalarFunction for F {} - impl SqliteConnection { pub fn create_scalar_function(&self, fn_name: &str, @@ -263,31 +280,34 @@ impl SqliteConnection { deterministic: bool, x_func: F) -> SqliteResult<()> - where F: ScalarFunction + where F: FnMut(&Context) { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } } impl InnerSqliteConnection { - pub fn create_scalar_function(&mut self, - fn_name: &str, - n_arg: c_int, - deterministic: bool, - x_func: F) - -> SqliteResult<()> - where F: ScalarFunction + fn create_scalar_function(&mut self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + x_func: F) + -> SqliteResult<()> + where F: FnMut(&Context) { extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value) - where F: ScalarFunction + where F: FnMut(&Context) { - let ctx = Context { ctx: ctx }; unsafe { + let ctx = Context { + ctx: ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx.ctx)); assert!(!boxed_f.is_null(), "Internal error - null function pointer"); - (*boxed_f)(&ctx, argc, argv); + (*boxed_f)(&ctx); } } @@ -317,23 +337,18 @@ mod test { extern crate regex; use std::ffi::CString; - use libc::{c_int, c_double}; + use libc::c_double; use self::regex::Regex; use SqliteConnection; use ffi; - use ffi::sqlite3_value; - use functions::{Context, FromValue, ToResult}; + use functions::{Context, ToResult}; - fn half(ctx: &Context, _: c_int, argv: *mut *mut sqlite3_value) { - unsafe { - let arg = *argv.offset(0); - if c_double::parameter_has_valid_sqlite_type(arg) { - let value = c_double::parameter_value(arg).unwrap() / 2f64; - value.set_result(ctx.ctx); - } else { - ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_MISMATCH); - } + fn half(ctx: &Context) { + assert!(ctx.len() == 1, "called with unexpected number of arguments"); + match ctx.get::(0) { + Ok(value) => unsafe { (value / 2f64).set_result(ctx.ctx) }, + Err(err) => unsafe { ffi::sqlite3_result_error_code(ctx.ctx, err.code) }, } } @@ -346,11 +361,13 @@ mod test { assert_eq!(3f64, result.unwrap()); } - fn regexp(ctx: &Context, _: c_int, argv: *mut *mut sqlite3_value) { + fn regexp(ctx: &Context) { + assert!(ctx.len() == 2, "called with unexpected number of arguments"); + let saved_re: Option<&Regex> = unsafe { ctx.get_aux(0) }; let new_re = match saved_re { None => unsafe { - let raw = String::parameter_value(*argv.offset(0)); + let raw = ctx.get::(0); if raw.is_err() { let msg = CString::new(format!("{}", raw.unwrap_err())).unwrap(); ffi::sqlite3_result_error(ctx.ctx, msg.as_ptr(), -1); @@ -370,7 +387,7 @@ mod test { { let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); - let text = unsafe { String::parameter_value(*argv.offset(1)) }; + let text = ctx.get::(1); if text.is_ok() { let text = text.unwrap(); unsafe { @@ -391,13 +408,27 @@ mod test { } #[test] + #[cfg_attr(rustfmt, rustfmt_skip)] fn test_function_regexp() { let db = SqliteConnection::open_in_memory().unwrap(); + db.execute_batch("BEGIN; + CREATE TABLE foo (x string); + INSERT INTO foo VALUES ('lisa'); + INSERT INTO foo VALUES ('lXsi'); + INSERT INTO foo VALUES ('lisX'); + END;").unwrap(); db.create_scalar_function("regexp", 2, true, regexp).unwrap(); + let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", &[], |r| r.get::(0)); assert_eq!(true, result.unwrap()); + + let result = db.query_row("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1", + &[], + |r| r.get::(0)); + + assert_eq!(2, result.unwrap()); } } From 3913e89f9447d75cef5adffeda8e1808e418e88c Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Fri, 11 Dec 2015 15:08:40 -0500 Subject: [PATCH 08/13] Allow user scalar functions to return results. This removes the need for scalar functions to have direct access to the context (in order to set the return value). --- src/functions.rs | 114 +++++++++++++++++++++-------------------------- 1 file changed, 51 insertions(+), 63 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 6481b8a..0552dc2 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -221,16 +221,12 @@ impl FromValue for Option { } } -// sqlite3_user_data -// sqlite3_get_auxdata -// sqlite3_set_auxdata - unsafe extern "C" fn free_boxed_value(p: *mut c_void) { let _: Box = Box::from_raw(mem::transmute(p)); } pub struct Context<'a> { - pub ctx: *mut sqlite3_context, + ctx: *mut sqlite3_context, args: &'a [*mut sqlite3_value], } @@ -274,31 +270,34 @@ impl<'a> Context<'a> { } impl SqliteConnection { - pub fn create_scalar_function(&self, - fn_name: &str, - n_arg: c_int, - deterministic: bool, - x_func: F) - -> SqliteResult<()> - where F: FnMut(&Context) + pub fn create_scalar_function(&self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + x_func: F) + -> SqliteResult<()> + where F: FnMut(&Context) -> SqliteResult, + T: ToResult { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } } impl InnerSqliteConnection { - fn create_scalar_function(&mut self, - fn_name: &str, - n_arg: c_int, - deterministic: bool, - x_func: F) - -> SqliteResult<()> - where F: FnMut(&Context) + fn create_scalar_function(&mut self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + x_func: F) + -> SqliteResult<()> + where F: FnMut(&Context) -> SqliteResult, + T: ToResult { - extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, - argc: c_int, - argv: *mut *mut sqlite3_value) - where F: FnMut(&Context) + extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value) + where F: FnMut(&Context) -> SqliteResult, + T: ToResult { unsafe { let ctx = Context { @@ -307,7 +306,15 @@ impl InnerSqliteConnection { }; let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx.ctx)); assert!(!boxed_f.is_null(), "Internal error - null function pointer"); - (*boxed_f)(&ctx); + match (*boxed_f)(&ctx) { + Ok(r) => r.set_result(ctx.ctx), + Err(e) => { + ffi::sqlite3_result_error_code(ctx.ctx, e.code); + if let Ok(cstr) = str_to_cstring(&e.message) { + ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); + } + }, + } } } @@ -323,7 +330,7 @@ impl InnerSqliteConnection { n_arg, flags, mem::transmute(boxed_f), - Some(call_boxed_closure::), + Some(call_boxed_closure::), None, None, Some(mem::transmute(free_boxed_value::))) @@ -336,20 +343,17 @@ impl InnerSqliteConnection { mod test { extern crate regex; - use std::ffi::CString; use libc::c_double; use self::regex::Regex; - use SqliteConnection; + use {SqliteConnection, SqliteError, SqliteResult}; use ffi; - use functions::{Context, ToResult}; + use functions::Context; - fn half(ctx: &Context) { + fn half(ctx: &Context) -> SqliteResult { assert!(ctx.len() == 1, "called with unexpected number of arguments"); - match ctx.get::(0) { - Ok(value) => unsafe { (value / 2f64).set_result(ctx.ctx) }, - Err(err) => unsafe { ffi::sqlite3_result_error_code(ctx.ctx, err.code) }, - } + let value = try!(ctx.get::(0)); + Ok(value / 2f64) } #[test] @@ -361,50 +365,34 @@ mod test { assert_eq!(3f64, result.unwrap()); } - fn regexp(ctx: &Context) { + fn regexp(ctx: &Context) -> SqliteResult { assert!(ctx.len() == 2, "called with unexpected number of arguments"); let saved_re: Option<&Regex> = unsafe { ctx.get_aux(0) }; let new_re = match saved_re { - None => unsafe { - let raw = ctx.get::(0); - if raw.is_err() { - let msg = CString::new(format!("{}", raw.unwrap_err())).unwrap(); - ffi::sqlite3_result_error(ctx.ctx, msg.as_ptr(), -1); - return; - } - let comp = Regex::new(raw.unwrap().as_ref()); - if comp.is_err() { - let msg = CString::new(format!("{}", comp.unwrap_err())).unwrap(); - ffi::sqlite3_result_error(ctx.ctx, msg.as_ptr(), -1); - return; - } - Some(comp.unwrap()) + None => { + let s = try!(ctx.get::(0)); + let r = try!(Regex::new(&s).map_err(|e| SqliteError { + code: ffi::SQLITE_ERROR, + message: format!("Invalid regular expression: {}", e), + })); + Some(r) }, Some(_) => None, }; - { + let is_match = { let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); - let text = ctx.get::(1); - if text.is_ok() { - let text = text.unwrap(); - unsafe { - re.is_match(text.as_ref()).set_result(ctx.ctx); - } - } else { - let msg = CString::new(format!("{}", text.unwrap_err())).unwrap(); - unsafe { - ffi::sqlite3_result_error(ctx.ctx, msg.as_ptr(), -1); - } - return; - } - } + let text = try!(ctx.get::(1)); + re.is_match(&text) + }; if let Some(re) = new_re { ctx.set_aux(0, re); } + + Ok(is_match) } #[test] From 3baf7b10f81befb6d338e9867dfa29150aa83e0d Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Fri, 11 Dec 2015 15:35:59 -0500 Subject: [PATCH 09/13] Add unit test demonstrating a closure-based UDF. --- src/functions.rs | 73 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 65 insertions(+), 8 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 0552dc2..2809f47 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -313,7 +313,7 @@ impl InnerSqliteConnection { if let Ok(cstr) = str_to_cstring(&e.message) { ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); } - }, + } } } } @@ -343,6 +343,7 @@ impl InnerSqliteConnection { mod test { extern crate regex; + use std::collections::HashMap; use libc::c_double; use self::regex::Regex; @@ -365,19 +366,24 @@ mod test { assert_eq!(3f64, result.unwrap()); } - fn regexp(ctx: &Context) -> SqliteResult { + // This implementation of a regexp scalar function uses SQLite's auxilliary data + // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular + // expression multiple times within one query. + fn regexp_with_auxilliary(ctx: &Context) -> SqliteResult { assert!(ctx.len() == 2, "called with unexpected number of arguments"); let saved_re: Option<&Regex> = unsafe { ctx.get_aux(0) }; let new_re = match saved_re { None => { let s = try!(ctx.get::(0)); - let r = try!(Regex::new(&s).map_err(|e| SqliteError { - code: ffi::SQLITE_ERROR, - message: format!("Invalid regular expression: {}", e), + let r = try!(Regex::new(&s).map_err(|e| { + SqliteError { + code: ffi::SQLITE_ERROR, + message: format!("Invalid regular expression: {}", e), + } })); Some(r) - }, + } Some(_) => None, }; @@ -397,7 +403,7 @@ mod test { #[test] #[cfg_attr(rustfmt, rustfmt_skip)] - fn test_function_regexp() { + fn test_function_regexp_with_auxilliary() { let db = SqliteConnection::open_in_memory().unwrap(); db.execute_batch("BEGIN; CREATE TABLE foo (x string); @@ -405,7 +411,58 @@ mod test { INSERT INTO foo VALUES ('lXsi'); INSERT INTO foo VALUES ('lisX'); END;").unwrap(); - db.create_scalar_function("regexp", 2, true, regexp).unwrap(); + db.create_scalar_function("regexp", 2, true, regexp_with_auxilliary).unwrap(); + + let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", + &[], + |r| r.get::(0)); + + assert_eq!(true, result.unwrap()); + + let result = db.query_row("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1", + &[], + |r| r.get::(0)); + + assert_eq!(2, result.unwrap()); + } + + #[test] + #[cfg_attr(rustfmt, rustfmt_skip)] + fn test_function_regexp_with_hashmap_cache() { + let db = SqliteConnection::open_in_memory().unwrap(); + db.execute_batch("BEGIN; + CREATE TABLE foo (x string); + INSERT INTO foo VALUES ('lisa'); + INSERT INTO foo VALUES ('lXsi'); + INSERT INTO foo VALUES ('lisX'); + END;").unwrap(); + + // This implementation of a regexp scalar function uses a captured HashMap + // to keep cached regular expressions around (even across multiple queries) + // until the function is removed. + let mut cached_regexes = HashMap::new(); + db.create_scalar_function("regexp", 2, true, move |ctx| { + assert!(ctx.len() == 2, "called with unexpected number of arguments"); + + let regex_s = try!(ctx.get::(0)); + let entry = cached_regexes.entry(regex_s.clone()); + let regex = { + use std::collections::hash_map::Entry::{Occupied, Vacant}; + match entry { + Occupied(occ) => occ.into_mut(), + Vacant(vac) => { + let r = try!(Regex::new(®ex_s).map_err(|e| SqliteError { + code: ffi::SQLITE_ERROR, + message: format!("Invalid regular expression: {}", e), + })); + vac.insert(r) + } + } + }; + + let text = try!(ctx.get::(1)); + Ok(regex.is_match(&text)) + }).unwrap(); let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", &[], From ecef092303c872eb0a5dcff426b1f4ecdbd74b4e Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Fri, 11 Dec 2015 15:47:52 -0500 Subject: [PATCH 10/13] Add `remove_function` to clear a user-defined function. --- src/functions.rs | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/functions.rs b/src/functions.rs index 2809f47..9fc2bc2 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -1,6 +1,7 @@ //! Create or redefine SQL functions use std::ffi::CStr; use std::mem; +use std::ptr; use std::slice; use std::str; use libc::{c_int, c_double, c_char, c_void}; @@ -281,6 +282,10 @@ impl SqliteConnection { { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } + + pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> SqliteResult<()> { + self.db.borrow_mut().remove_function(fn_name, n_arg) + } } impl InnerSqliteConnection { @@ -337,6 +342,22 @@ impl InnerSqliteConnection { }; self.decode_result(r) } + + fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> SqliteResult<()> { + let c_name = try!(str_to_cstring(fn_name)); + let r = unsafe { + ffi::sqlite3_create_function_v2(self.db(), + c_name.as_ptr(), + n_arg, + ffi::SQLITE_UTF8, + ptr::null_mut(), + None, + None, + None, + None) + }; + self.decode_result(r) + } } #[cfg(test)] @@ -366,6 +387,18 @@ mod test { assert_eq!(3f64, result.unwrap()); } + #[test] + fn test_remove_function() { + let db = SqliteConnection::open_in_memory().unwrap(); + db.create_scalar_function("half", 1, true, half).unwrap(); + let result = db.query_row("SELECT half(6)", &[], |r| r.get::(0)); + assert_eq!(3f64, result.unwrap()); + + db.remove_function("half", 1).unwrap(); + let result = db.query_row("SELECT half(6)", &[], |r| r.get::(0)); + assert!(result.is_err()); + } + // This implementation of a regexp scalar function uses SQLite's auxilliary data // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular // expression multiple times within one query. From 3bcde498bda089020f70ff3bc0d71e879b3c7caf Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Fri, 11 Dec 2015 16:27:39 -0500 Subject: [PATCH 11/13] Expand comments. --- src/functions.rs | 113 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 111 insertions(+), 2 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 9fc2bc2..6826cf6 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -1,4 +1,55 @@ -//! Create or redefine SQL functions +//! Create or redefine SQL functions. +//! +//! # Example +//! +//! Adding a `regexp` function to a connection in which compiled regular expressions +//! are cached in a `HashMap`. For an alternative implementation that uses SQLite's +//! [Function Auxilliary Data](https://www.sqlite.org/c3ref/get_auxdata.html) interface +//! to avoid recompiling regular expressions, see the unit tests for this module. +//! +//! ```rust +//! extern crate libsqlite3_sys; +//! extern crate rusqlite; +//! extern crate regex; +//! +//! use rusqlite::{SqliteConnection, SqliteError, SqliteResult}; +//! use std::collections::HashMap; +//! use regex::Regex; +//! +//! fn add_regexp_function(db: &SqliteConnection) -> SqliteResult<()> { +//! let mut cached_regexes = HashMap::new(); +//! db.create_scalar_function("regexp", 2, true, move |ctx| { +//! let regex_s = try!(ctx.get::(0)); +//! let entry = cached_regexes.entry(regex_s.clone()); +//! let regex = { +//! use std::collections::hash_map::Entry::{Occupied, Vacant}; +//! match entry { +//! Occupied(occ) => occ.into_mut(), +//! Vacant(vac) => { +//! let r = try!(Regex::new(®ex_s).map_err(|e| SqliteError { +//! code: libsqlite3_sys::SQLITE_ERROR, +//! message: format!("Invalid regular expression: {}", e), +//! })); +//! vac.insert(r) +//! } +//! } +//! }; +//! +//! let text = try!(ctx.get::(1)); +//! Ok(regex.is_match(&text)) +//! }) +//! } +//! +//! fn main() { +//! let db = SqliteConnection::open_in_memory().unwrap(); +//! add_regexp_function(&db).unwrap(); +//! +//! let is_match = db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", &[], +//! |row| row.get::(0)).unwrap(); +//! +//! assert!(is_match); +//! } +//! ``` use std::ffi::CStr; use std::mem; use std::ptr; @@ -118,7 +169,7 @@ pub trait FromValue: Sized { /// FromValue types can implement this method and use sqlite3_value_type to check that /// the type reported by SQLite matches a type suitable for Self. This method is used - /// by `???` to confirm that the parameter contains a valid type before + /// by `Context::get` to confirm that the parameter contains a valid type before /// attempting to retrieve the value. unsafe fn parameter_has_valid_sqlite_type(_: *mut sqlite3_value) -> bool { true @@ -226,16 +277,25 @@ unsafe extern "C" fn free_boxed_value(p: *mut c_void) { let _: Box = Box::from_raw(mem::transmute(p)); } +/// Context is a wrapper for the SQLite function evaluation context. pub struct Context<'a> { ctx: *mut sqlite3_context, args: &'a [*mut sqlite3_value], } impl<'a> Context<'a> { + /// Returns the number of arguments to the function. pub fn len(&self) -> usize { self.args.len() } + /// Returns the `idx`th argument as a `T`. + /// + /// # Failure + /// + /// Will panic if `idx` is greater than or equal to `self.len()`. + /// + /// Will return Err if the underlying SQLite type cannot be converted to a `T`. pub fn get(&self, idx: usize) -> SqliteResult { let arg = self.args[idx]; unsafe { @@ -250,6 +310,9 @@ impl<'a> Context<'a> { } } + /// Sets the auxilliary data associated with a particular parameter. See + /// https://www.sqlite.org/c3ref/get_auxdata.html for a discussion of + /// this feature, or the unit tests of this module for an example. pub fn set_aux(&self, arg: c_int, value: T) { let boxed = Box::into_raw(Box::new(value)); unsafe { @@ -260,6 +323,14 @@ impl<'a> Context<'a> { }; } + /// Gets the auxilliary data that was associated with a given parameter + /// via `set_aux`. Returns `None` if no data has been associated. + /// + /// # Unsafety + /// + /// This function is unsafe as there is no guarantee that the type `T` + /// requested matches the type `T` that was provided to `set_aux`. The + /// types must be identical. pub unsafe fn get_aux(&self, arg: c_int) -> Option<&T> { let p = ffi::sqlite3_get_auxdata(self.ctx, arg) as *mut T; if p.is_null() { @@ -271,6 +342,36 @@ impl<'a> Context<'a> { } impl SqliteConnection { + /// Attach a user-defined scalar function to this database connection. + /// + /// `fn_name` is the name the function will be accessible from SQL. + /// `n_arg` is the number of arguments to the function. Use `-1` for a variable + /// number. If the function always returns the same value given the same + /// input, `deterministic` should be `true`. + /// + /// The function will remain available until the connection is closed or + /// until it is explicitly removed via `remove_function`. + /// + /// # Example + /// + /// ```rust + /// # use rusqlite::{SqliteConnection, SqliteResult}; + /// # type c_double = f64; + /// fn scalar_function_example(db: SqliteConnection) -> SqliteResult<()> { + /// try!(db.create_scalar_function("halve", 1, true, |ctx| { + /// let value = try!(ctx.get::(0)); + /// Ok(value / 2f64) + /// })); + /// + /// let six_halved = try!(db.query_row("SELECT halve(6)", &[], |r| r.get::(0))); + /// assert_eq!(six_halved, 3f64); + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return Err if the function could not be attached to the connection. pub fn create_scalar_function(&self, fn_name: &str, n_arg: c_int, @@ -283,6 +384,14 @@ impl SqliteConnection { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } + /// Removes a user-defined function from this database connection. + /// + /// `fn_name` and `n_arg` should match the name and number of arguments + /// given to `create_scalar_function`. + /// + /// # Failure + /// + /// Will return Err if the function could not be removed. pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> SqliteResult<()> { self.db.borrow_mut().remove_function(fn_name, n_arg) } From caf1e95e314ad2602279661d3bf826eda5810a96 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Fri, 11 Dec 2015 16:28:46 -0500 Subject: [PATCH 12/13] Add functions feature to travis and Changelog. --- .travis.yml | 5 +++-- Changelog.md | 2 ++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.travis.yml b/.travis.yml index a0743ea..2bd2336 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,8 +11,9 @@ script: - cargo test --features backup - cargo test --features load_extension - cargo test --features trace - - cargo test --features "backup load_extension trace" - - cargo doc --no-deps --features "backup load_extension trace" + - cargo test --features functions + - cargo test --features "backup functions load_extension trace" + - cargo doc --no-deps --features "backup functions load_extension trace" after_success: | [ $TRAVIS_BRANCH = master ] && diff --git a/Changelog.md b/Changelog.md index 822668b..77315c2 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,6 +1,8 @@ # Version UPCOMING (TBD) * Adds `backup` feature that exposes SQLite's online backup API. +* Adds `functions` feature that allows user-defined scalar functions to be added to + open `SqliteConnection`s. # Version 0.5.0 (2015-12-08) From 4830b0a64815d5c20eb55f225ef824db98679477 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Sat, 12 Dec 2015 10:44:08 -0500 Subject: [PATCH 13/13] Add unit test for function with variable number of arguments --- src/functions.rs | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/src/functions.rs b/src/functions.rs index 6826cf6..d8df619 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -618,4 +618,27 @@ mod test { assert_eq!(2, result.unwrap()); } + + #[test] + fn test_varargs_function() { + let db = SqliteConnection::open_in_memory().unwrap(); + db.create_scalar_function("my_concat", -1, true, |ctx| { + let mut ret = String::new(); + + for idx in 0..ctx.len() { + let s = try!(ctx.get::(idx)); + ret.push_str(&s); + } + + Ok(ret) + }) + .unwrap(); + + for &(expected, query) in &[("", "SELECT my_concat()"), + ("onetwo", "SELECT my_concat('one', 'two')"), + ("abc", "SELECT my_concat('a', 'b', 'c')")] { + let result: String = db.query_row(query, &[], |r| r.get(0)).unwrap(); + assert_eq!(expected, result); + } + } }