diff --git a/src/functions.rs b/src/functions.rs index bb386a6..4c024b2 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -224,8 +224,37 @@ impl FromValue for Option { // sqlite3_get_auxdata // sqlite3_set_auxdata -pub trait ScalarFunction: FnMut(*mut sqlite3_context, c_int, *mut *mut sqlite3_value) {} -impl ScalarFunction for F {} +unsafe extern "C" fn free_boxed_value(p: *mut c_void) { + let _: Box = Box::from_raw(mem::transmute(p)); +} + +pub struct Context { + pub ctx: *mut sqlite3_context, +} + +impl Context { + pub fn set_aux(&self, arg: c_int, value: T) { + let boxed = Box::into_raw(Box::new(value)); + unsafe { + ffi::sqlite3_set_auxdata(self.ctx, + arg, + mem::transmute(boxed), + Some(mem::transmute(free_boxed_value::))) + }; + } + + pub unsafe fn get_aux(&self, arg: c_int) -> Option<&T> { + let p = ffi::sqlite3_get_auxdata(self.ctx, arg) as *mut T; + if p.is_null() { + None + } else { + Some(&*p) + } + } +} + +pub trait ScalarFunction: FnMut(&Context, c_int, *mut *mut sqlite3_value) {} +impl ScalarFunction for F {} impl SqliteConnection { pub fn create_scalar_function(&self, @@ -249,21 +278,16 @@ impl InnerSqliteConnection { -> SqliteResult<()> where F: ScalarFunction { - extern "C" fn free_boxed_closure(p: *mut c_void) - where F: ScalarFunction - { - let _: Box = unsafe { Box::from_raw(mem::transmute(p)) }; - } - extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value) where F: ScalarFunction { + let ctx = Context { ctx: ctx }; unsafe { - let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx)); + let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx.ctx)); assert!(!boxed_f.is_null(), "Internal error - null function pointer"); - (*boxed_f)(ctx, argc, argv); + (*boxed_f)(&ctx, argc, argv); } } @@ -282,7 +306,7 @@ impl InnerSqliteConnection { Some(call_boxed_closure::), None, None, - Some(free_boxed_closure::)) + Some(mem::transmute(free_boxed_value::))) }; self.decode_result(r) } @@ -292,26 +316,23 @@ impl InnerSqliteConnection { mod test { extern crate regex; - use std::boxed::Box; use std::ffi::CString; - use std::mem; - use libc::{c_int, c_double, c_void}; + use libc::{c_int, c_double}; use self::regex::Regex; use SqliteConnection; use ffi; - use ffi::sqlite3_context; use ffi::sqlite3_value; - use functions::{FromValue, ToResult}; + use functions::{Context, FromValue, ToResult}; - fn half(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { + fn half(ctx: &Context, _: c_int, argv: *mut *mut sqlite3_value) { unsafe { let arg = *argv.offset(0); if c_double::parameter_has_valid_sqlite_type(arg) { let value = c_double::parameter_value(arg).unwrap() / 2f64; - value.set_result(ctx); + value.set_result(ctx.ctx); } else { - ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISMATCH); + ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_MISMATCH); } } } @@ -325,40 +346,47 @@ mod test { assert_eq!(3f64, result.unwrap()); } - extern "C" fn regexp_free(raw: *mut c_void) { - let _: Box = unsafe { Box::from_raw(mem::transmute(raw)) }; - } - - fn regexp(ctx: *mut sqlite3_context, _: c_int, argv: *mut *mut sqlite3_value) { - unsafe { - let mut re_ptr = ffi::sqlite3_get_auxdata(ctx, 0) as *const Regex; - let need_re = re_ptr.is_null(); - if need_re { + fn regexp(ctx: &Context, _: c_int, argv: *mut *mut sqlite3_value) { + let saved_re: Option<&Regex> = unsafe { ctx.get_aux(0) }; + let new_re = match saved_re { + None => unsafe { let raw = String::parameter_value(*argv.offset(0)); if raw.is_err() { let msg = CString::new(format!("{}", raw.unwrap_err())).unwrap(); - ffi::sqlite3_result_error(ctx, msg.as_ptr(), -1); + ffi::sqlite3_result_error(ctx.ctx, msg.as_ptr(), -1); return; } let comp = Regex::new(raw.unwrap().as_ref()); if comp.is_err() { let msg = CString::new(format!("{}", comp.unwrap_err())).unwrap(); - ffi::sqlite3_result_error(ctx, msg.as_ptr(), -1); + ffi::sqlite3_result_error(ctx.ctx, msg.as_ptr(), -1); return; } - let re = Box::new(comp.unwrap()); - re_ptr = Box::into_raw(re); - } + Some(comp.unwrap()) + }, + Some(_) => None, + }; - let text = String::parameter_value(*argv.offset(1)); + { + let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); + + let text = unsafe { String::parameter_value(*argv.offset(1)) }; if text.is_ok() { let text = text.unwrap(); - (*re_ptr).is_match(text.as_ref()).set_result(ctx); + unsafe { + re.is_match(text.as_ref()).set_result(ctx.ctx); + } + } else { + let msg = CString::new(format!("{}", text.unwrap_err())).unwrap(); + unsafe { + ffi::sqlite3_result_error(ctx.ctx, msg.as_ptr(), -1); + } + return; } + } - if need_re { - ffi::sqlite3_set_auxdata(ctx, 0, mem::transmute(re_ptr), Some(regexp_free)); - } + if let Some(re) = new_re { + ctx.set_aux(0, re); } }