diff --git a/src/functions.rs b/src/functions.rs index 52ed7ca..bb386a6 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -1,9 +1,8 @@ //! Create or redefine SQL functions use std::ffi::CStr; use std::mem; -use std::ptr; use std::str; -use libc::{c_int, c_double, c_char}; +use libc::{c_int, c_double, c_char, c_void}; use ffi; pub use ffi::sqlite3_context; @@ -225,30 +224,50 @@ impl FromValue for Option { // sqlite3_get_auxdata // sqlite3_set_auxdata -pub type ScalarFunc = Option; +pub trait ScalarFunction: FnMut(*mut sqlite3_context, c_int, *mut *mut sqlite3_value) {} +impl ScalarFunction for F {} impl SqliteConnection { - // TODO pApp - pub fn create_scalar_function(&self, - fn_name: &str, - n_arg: c_int, - deterministic: bool, - x_func: ScalarFunc) - -> SqliteResult<()> { + pub fn create_scalar_function(&self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + x_func: F) + -> SqliteResult<()> + where F: ScalarFunction + { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } } impl InnerSqliteConnection { - pub fn create_scalar_function(&mut self, - fn_name: &str, - n_arg: c_int, - deterministic: bool, - x_func: ScalarFunc) - -> SqliteResult<()> { + pub fn create_scalar_function(&mut self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + x_func: F) + -> SqliteResult<()> + where F: ScalarFunction + { + extern "C" fn free_boxed_closure(p: *mut c_void) + where F: ScalarFunction + { + let _: Box = unsafe { Box::from_raw(mem::transmute(p)) }; + } + + extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value) + where F: ScalarFunction + { + unsafe { + let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx)); + assert!(!boxed_f.is_null(), "Internal error - null function pointer"); + (*boxed_f)(ctx, argc, argv); + } + } + + let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); let c_name = try!(str_to_cstring(fn_name)); let mut flags = ffi::SQLITE_UTF8; if deterministic { @@ -259,11 +278,11 @@ impl InnerSqliteConnection { c_name.as_ptr(), n_arg, flags, - ptr::null_mut(), - x_func, + mem::transmute(boxed_f), + Some(call_boxed_closure::), None, None, - None) + Some(free_boxed_closure::)) }; self.decode_result(r) } @@ -285,7 +304,7 @@ mod test { use ffi::sqlite3_value; use functions::{FromValue, ToResult}; - extern "C" fn half(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { + fn half(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { unsafe { let arg = *argv.offset(0); if c_double::parameter_has_valid_sqlite_type(arg) { @@ -298,9 +317,9 @@ mod test { } #[test] - fn test_half() { + fn test_function_half() { let db = SqliteConnection::open_in_memory().unwrap(); - db.create_scalar_function("half", 1, true, Some(half)).unwrap(); + db.create_scalar_function("half", 1, true, half).unwrap(); let result = db.query_row("SELECT half(6)", &[], |r| r.get::(0)); assert_eq!(3f64, result.unwrap()); @@ -310,7 +329,7 @@ mod test { let _: Box = unsafe { Box::from_raw(mem::transmute(raw)) }; } - extern "C" fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { + fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { unsafe { let mut re_ptr = ffi::sqlite3_get_auxdata(ctx, 0) as *const Regex; let need_re = re_ptr.is_null(); @@ -344,9 +363,9 @@ mod test { } #[test] - fn test_regexp() { + fn test_function_regexp() { let db = SqliteConnection::open_in_memory().unwrap(); - db.create_scalar_function("regexp", 2, true, Some(regexp)).unwrap(); + db.create_scalar_function("regexp", 2, true, regexp).unwrap(); let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", &[], |r| r.get::(0));