From 29494f46f6cf7528caed429048f1a60afe8e7134 Mon Sep 17 00:00:00 2001
From: John Gallagher <jgallagher@bignerdranch.com>
Date: Fri, 11 Dec 2015 12:01:05 -0500
Subject: [PATCH] Let create_scalar_function take an FnMut instead of a extern
 "C" fn.

---
 src/functions.rs | 75 ++++++++++++++++++++++++++++++------------------
 1 file changed, 47 insertions(+), 28 deletions(-)

diff --git a/src/functions.rs b/src/functions.rs
index 52ed7ca..bb386a6 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -1,9 +1,8 @@
 //! Create or redefine SQL functions
 use std::ffi::CStr;
 use std::mem;
-use std::ptr;
 use std::str;
-use libc::{c_int, c_double, c_char};
+use libc::{c_int, c_double, c_char, c_void};
 
 use ffi;
 pub use ffi::sqlite3_context;
@@ -225,30 +224,50 @@ impl<T: FromValue> FromValue for Option<T> {
 // sqlite3_get_auxdata
 // sqlite3_set_auxdata
 
-pub type ScalarFunc = Option<extern "C" fn(ctx: *mut sqlite3_context,
-                                           argc: c_int,
-                                           argv: *mut *mut sqlite3_value)
-                                          >;
+pub trait ScalarFunction: FnMut(*mut sqlite3_context, c_int, *mut *mut sqlite3_value) {}
+impl<F: FnMut(*mut sqlite3_context, c_int, *mut *mut sqlite3_value)> ScalarFunction for F {}
 
 impl SqliteConnection {
-    // TODO pApp
-    pub fn create_scalar_function(&self,
-                                  fn_name: &str,
-                                  n_arg: c_int,
-                                  deterministic: bool,
-                                  x_func: ScalarFunc)
-                                  -> SqliteResult<()> {
+    pub fn create_scalar_function<F>(&self,
+                                     fn_name: &str,
+                                     n_arg: c_int,
+                                     deterministic: bool,
+                                     x_func: F)
+                                     -> SqliteResult<()>
+        where F: ScalarFunction
+    {
         self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func)
     }
 }
 
 impl InnerSqliteConnection {
-    pub fn create_scalar_function(&mut self,
-                                  fn_name: &str,
-                                  n_arg: c_int,
-                                  deterministic: bool,
-                                  x_func: ScalarFunc)
-                                  -> SqliteResult<()> {
+    pub fn create_scalar_function<F>(&mut self,
+                                     fn_name: &str,
+                                     n_arg: c_int,
+                                     deterministic: bool,
+                                     x_func: F)
+                                     -> SqliteResult<()>
+        where F: ScalarFunction
+    {
+        extern "C" fn free_boxed_closure<F>(p: *mut c_void)
+            where F: ScalarFunction
+        {
+            let _: Box<F> = unsafe { Box::from_raw(mem::transmute(p)) };
+        }
+
+        extern "C" fn call_boxed_closure<F>(ctx: *mut sqlite3_context,
+                                            argc: c_int,
+                                            argv: *mut *mut sqlite3_value)
+            where F: ScalarFunction
+        {
+            unsafe {
+                let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx));
+                assert!(!boxed_f.is_null(), "Internal error - null function pointer");
+                (*boxed_f)(ctx, argc, argv);
+            }
+        }
+
+        let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
         let c_name = try!(str_to_cstring(fn_name));
         let mut flags = ffi::SQLITE_UTF8;
         if deterministic {
@@ -259,11 +278,11 @@ impl InnerSqliteConnection {
                                             c_name.as_ptr(),
                                             n_arg,
                                             flags,
-                                            ptr::null_mut(),
-                                            x_func,
+                                            mem::transmute(boxed_f),
+                                            Some(call_boxed_closure::<F>),
                                             None,
                                             None,
-                                            None)
+                                            Some(free_boxed_closure::<F>))
         };
         self.decode_result(r)
     }
@@ -285,7 +304,7 @@ mod test {
     use ffi::sqlite3_value;
     use functions::{FromValue, ToResult};
 
-    extern "C" fn half(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) {
+    fn half(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) {
         unsafe {
             let arg = *argv.offset(0);
             if c_double::parameter_has_valid_sqlite_type(arg) {
@@ -298,9 +317,9 @@ mod test {
     }
 
     #[test]
-    fn test_half() {
+    fn test_function_half() {
         let db = SqliteConnection::open_in_memory().unwrap();
-        db.create_scalar_function("half", 1, true, Some(half)).unwrap();
+        db.create_scalar_function("half", 1, true, half).unwrap();
         let result = db.query_row("SELECT half(6)", &[], |r| r.get::<f64>(0));
 
         assert_eq!(3f64, result.unwrap());
@@ -310,7 +329,7 @@ mod test {
         let _: Box<Regex> = unsafe { Box::from_raw(mem::transmute(raw)) };
     }
 
-    extern "C" fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) {
+    fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) {
         unsafe {
             let mut re_ptr = ffi::sqlite3_get_auxdata(ctx, 0) as *const Regex;
             let need_re = re_ptr.is_null();
@@ -344,9 +363,9 @@ mod test {
     }
 
     #[test]
-    fn test_regexp() {
+    fn test_function_regexp() {
         let db = SqliteConnection::open_in_memory().unwrap();
-        db.create_scalar_function("regexp", 2, true, Some(regexp)).unwrap();
+        db.create_scalar_function("regexp", 2, true, regexp).unwrap();
         let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')",
                                   &[],
                                   |r| r.get::<bool>(0));