From 29494f46f6cf7528caed429048f1a60afe8e7134 Mon Sep 17 00:00:00 2001 From: John Gallagher <jgallagher@bignerdranch.com> Date: Fri, 11 Dec 2015 12:01:05 -0500 Subject: [PATCH] Let create_scalar_function take an FnMut instead of a extern "C" fn. --- src/functions.rs | 75 ++++++++++++++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 28 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 52ed7ca..bb386a6 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -1,9 +1,8 @@ //! Create or redefine SQL functions use std::ffi::CStr; use std::mem; -use std::ptr; use std::str; -use libc::{c_int, c_double, c_char}; +use libc::{c_int, c_double, c_char, c_void}; use ffi; pub use ffi::sqlite3_context; @@ -225,30 +224,50 @@ impl<T: FromValue> FromValue for Option<T> { // sqlite3_get_auxdata // sqlite3_set_auxdata -pub type ScalarFunc = Option<extern "C" fn(ctx: *mut sqlite3_context, - argc: c_int, - argv: *mut *mut sqlite3_value) - >; +pub trait ScalarFunction: FnMut(*mut sqlite3_context, c_int, *mut *mut sqlite3_value) {} +impl<F: FnMut(*mut sqlite3_context, c_int, *mut *mut sqlite3_value)> ScalarFunction for F {} impl SqliteConnection { - // TODO pApp - pub fn create_scalar_function(&self, - fn_name: &str, - n_arg: c_int, - deterministic: bool, - x_func: ScalarFunc) - -> SqliteResult<()> { + pub fn create_scalar_function<F>(&self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + x_func: F) + -> SqliteResult<()> + where F: ScalarFunction + { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } } impl InnerSqliteConnection { - pub fn create_scalar_function(&mut self, - fn_name: &str, - n_arg: c_int, - deterministic: bool, - x_func: ScalarFunc) - -> SqliteResult<()> { + pub fn create_scalar_function<F>(&mut self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + x_func: F) + -> SqliteResult<()> + where F: ScalarFunction + { + extern "C" fn free_boxed_closure<F>(p: *mut c_void) + where F: ScalarFunction + { + let _: Box<F> = unsafe { Box::from_raw(mem::transmute(p)) }; + } + + extern "C" fn call_boxed_closure<F>(ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value) + where F: ScalarFunction + { + unsafe { + let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx)); + assert!(!boxed_f.is_null(), "Internal error - null function pointer"); + (*boxed_f)(ctx, argc, argv); + } + } + + let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); let c_name = try!(str_to_cstring(fn_name)); let mut flags = ffi::SQLITE_UTF8; if deterministic { @@ -259,11 +278,11 @@ impl InnerSqliteConnection { c_name.as_ptr(), n_arg, flags, - ptr::null_mut(), - x_func, + mem::transmute(boxed_f), + Some(call_boxed_closure::<F>), None, None, - None) + Some(free_boxed_closure::<F>)) }; self.decode_result(r) } @@ -285,7 +304,7 @@ mod test { use ffi::sqlite3_value; use functions::{FromValue, ToResult}; - extern "C" fn half(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { + fn half(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { unsafe { let arg = *argv.offset(0); if c_double::parameter_has_valid_sqlite_type(arg) { @@ -298,9 +317,9 @@ mod test { } #[test] - fn test_half() { + fn test_function_half() { let db = SqliteConnection::open_in_memory().unwrap(); - db.create_scalar_function("half", 1, true, Some(half)).unwrap(); + db.create_scalar_function("half", 1, true, half).unwrap(); let result = db.query_row("SELECT half(6)", &[], |r| r.get::<f64>(0)); assert_eq!(3f64, result.unwrap()); @@ -310,7 +329,7 @@ mod test { let _: Box<Regex> = unsafe { Box::from_raw(mem::transmute(raw)) }; } - extern "C" fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { + fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { unsafe { let mut re_ptr = ffi::sqlite3_get_auxdata(ctx, 0) as *const Regex; let need_re = re_ptr.is_null(); @@ -344,9 +363,9 @@ mod test { } #[test] - fn test_regexp() { + fn test_function_regexp() { let db = SqliteConnection::open_in_memory().unwrap(); - db.create_scalar_function("regexp", 2, true, Some(regexp)).unwrap(); + db.create_scalar_function("regexp", 2, true, regexp).unwrap(); let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", &[], |r| r.get::<bool>(0));