From 81ec7fe7cd12bf2bb28e07170984df0adcdfa47e Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Fri, 11 Dec 2015 14:46:28 -0500 Subject: [PATCH] Add `get` to function::Context. This allows user-defined functions to now only accept a `Context`, as it embeds the arguments inside itself. --- src/functions.rs | 93 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 62 insertions(+), 31 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 4c024b2..6481b8a 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -1,6 +1,7 @@ //! Create or redefine SQL functions use std::ffi::CStr; use std::mem; +use std::slice; use std::str; use libc::{c_int, c_double, c_char, c_void}; @@ -228,11 +229,30 @@ unsafe extern "C" fn free_boxed_value(p: *mut c_void) { let _: Box = Box::from_raw(mem::transmute(p)); } -pub struct Context { +pub struct Context<'a> { pub ctx: *mut sqlite3_context, + args: &'a [*mut sqlite3_value], } -impl Context { +impl<'a> Context<'a> { + pub fn len(&self) -> usize { + self.args.len() + } + + pub fn get(&self, idx: usize) -> SqliteResult { + let arg = self.args[idx]; + unsafe { + if T::parameter_has_valid_sqlite_type(arg) { + T::parameter_value(arg) + } else { + Err(SqliteError { + code: ffi::SQLITE_MISMATCH, + message: "Invalid value type".to_string(), + }) + } + } + } + pub fn set_aux(&self, arg: c_int, value: T) { let boxed = Box::into_raw(Box::new(value)); unsafe { @@ -253,9 +273,6 @@ impl Context { } } -pub trait ScalarFunction: FnMut(&Context, c_int, *mut *mut sqlite3_value) {} -impl ScalarFunction for F {} - impl SqliteConnection { pub fn create_scalar_function(&self, fn_name: &str, @@ -263,31 +280,34 @@ impl SqliteConnection { deterministic: bool, x_func: F) -> SqliteResult<()> - where F: ScalarFunction + where F: FnMut(&Context) { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } } impl InnerSqliteConnection { - pub fn create_scalar_function(&mut self, - fn_name: &str, - n_arg: c_int, - deterministic: bool, - x_func: F) - -> SqliteResult<()> - where F: ScalarFunction + fn create_scalar_function(&mut self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + x_func: F) + -> SqliteResult<()> + where F: FnMut(&Context) { extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value) - where F: ScalarFunction + where F: FnMut(&Context) { - let ctx = Context { ctx: ctx }; unsafe { + let ctx = Context { + ctx: ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx.ctx)); assert!(!boxed_f.is_null(), "Internal error - null function pointer"); - (*boxed_f)(&ctx, argc, argv); + (*boxed_f)(&ctx); } } @@ -317,23 +337,18 @@ mod test { extern crate regex; use std::ffi::CString; - use libc::{c_int, c_double}; + use libc::c_double; use self::regex::Regex; use SqliteConnection; use ffi; - use ffi::sqlite3_value; - use functions::{Context, FromValue, ToResult}; + use functions::{Context, ToResult}; - fn half(ctx: &Context, _: c_int, argv: *mut *mut sqlite3_value) { - unsafe { - let arg = *argv.offset(0); - if c_double::parameter_has_valid_sqlite_type(arg) { - let value = c_double::parameter_value(arg).unwrap() / 2f64; - value.set_result(ctx.ctx); - } else { - ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_MISMATCH); - } + fn half(ctx: &Context) { + assert!(ctx.len() == 1, "called with unexpected number of arguments"); + match ctx.get::(0) { + Ok(value) => unsafe { (value / 2f64).set_result(ctx.ctx) }, + Err(err) => unsafe { ffi::sqlite3_result_error_code(ctx.ctx, err.code) }, } } @@ -346,11 +361,13 @@ mod test { assert_eq!(3f64, result.unwrap()); } - fn regexp(ctx: &Context, _: c_int, argv: *mut *mut sqlite3_value) { + fn regexp(ctx: &Context) { + assert!(ctx.len() == 2, "called with unexpected number of arguments"); + let saved_re: Option<&Regex> = unsafe { ctx.get_aux(0) }; let new_re = match saved_re { None => unsafe { - let raw = String::parameter_value(*argv.offset(0)); + let raw = ctx.get::(0); if raw.is_err() { let msg = CString::new(format!("{}", raw.unwrap_err())).unwrap(); ffi::sqlite3_result_error(ctx.ctx, msg.as_ptr(), -1); @@ -370,7 +387,7 @@ mod test { { let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); - let text = unsafe { String::parameter_value(*argv.offset(1)) }; + let text = ctx.get::(1); if text.is_ok() { let text = text.unwrap(); unsafe { @@ -391,13 +408,27 @@ mod test { } #[test] + #[cfg_attr(rustfmt, rustfmt_skip)] fn test_function_regexp() { let db = SqliteConnection::open_in_memory().unwrap(); + db.execute_batch("BEGIN; + CREATE TABLE foo (x string); + INSERT INTO foo VALUES ('lisa'); + INSERT INTO foo VALUES ('lXsi'); + INSERT INTO foo VALUES ('lisX'); + END;").unwrap(); db.create_scalar_function("regexp", 2, true, regexp).unwrap(); + let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", &[], |r| r.get::(0)); assert_eq!(true, result.unwrap()); + + let result = db.query_row("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1", + &[], + |r| r.get::(0)); + + assert_eq!(2, result.unwrap()); } }