//! Create or redefine SQL functions use std::ffi::CStr; use std::mem; use std::slice; use std::str; use libc::{c_int, c_double, c_char, c_void}; use ffi; pub use ffi::sqlite3_context; pub use ffi::sqlite3_value; pub use ffi::sqlite3_value_type; pub use ffi::sqlite3_value_numeric_type; use types::Null; use {SqliteResult, SqliteError, SqliteConnection, str_to_cstring, InnerSqliteConnection}; /// A trait for types that can be converted into the result of an SQL function. pub trait ToResult { unsafe fn set_result(&self, ctx: *mut sqlite3_context); } macro_rules! raw_to_impl( ($t:ty, $f:ident) => ( impl ToResult for $t { unsafe fn set_result(&self, ctx: *mut sqlite3_context) { ffi::$f(ctx, *self) } } ) ); raw_to_impl!(c_int, sqlite3_result_int); raw_to_impl!(i64, sqlite3_result_int64); raw_to_impl!(c_double, sqlite3_result_double); impl<'a> ToResult for bool { unsafe fn set_result(&self, ctx: *mut sqlite3_context) { match *self { true => ffi::sqlite3_result_int(ctx, 1), _ => ffi::sqlite3_result_int(ctx, 0), } } } impl<'a> ToResult for &'a str { unsafe fn set_result(&self, ctx: *mut sqlite3_context) { let length = self.len(); if length > ::std::i32::MAX as usize { ffi::sqlite3_result_error_toobig(ctx); return; } match str_to_cstring(self) { Ok(c_str) => { ffi::sqlite3_result_text(ctx, c_str.as_ptr(), length as c_int, ffi::SQLITE_TRANSIENT()) } Err(_) => ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISUSE), // TODO sqlite3_result_error } } } impl ToResult for String { unsafe fn set_result(&self, ctx: *mut sqlite3_context) { (&self[..]).set_result(ctx) } } impl<'a> ToResult for &'a [u8] { unsafe fn set_result(&self, ctx: *mut sqlite3_context) { if self.len() > ::std::i32::MAX as usize { ffi::sqlite3_result_error_toobig(ctx); return; } ffi::sqlite3_result_blob(ctx, mem::transmute(self.as_ptr()), self.len() as c_int, ffi::SQLITE_TRANSIENT()) } } impl ToResult for Vec { unsafe fn set_result(&self, ctx: *mut sqlite3_context) { (&self[..]).set_result(ctx) } } impl ToResult for Option { unsafe fn set_result(&self, ctx: *mut sqlite3_context) { match *self { None => ffi::sqlite3_result_null(ctx), Some(ref t) => t.set_result(ctx), } } } impl ToResult for Null { unsafe fn set_result(&self, ctx: *mut sqlite3_context) { ffi::sqlite3_result_null(ctx) } } // sqlite3_result_error_code, c_int // sqlite3_result_error_nomem // sqlite3_result_error_toobig // sqlite3_result_error, *const c_char, c_int // sqlite3_result_zeroblob // sqlite3_result_value /// A trait for types that can be created from a SQLite function parameter value. pub trait FromValue: Sized { unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult; /// FromValue types can implement this method and use sqlite3_value_type to check that /// the type reported by SQLite matches a type suitable for Self. This method is used /// by `???` to confirm that the parameter contains a valid type before /// attempting to retrieve the value. unsafe fn parameter_has_valid_sqlite_type(_: *mut sqlite3_value) -> bool { true } } macro_rules! raw_from_impl( ($t:ty, $f:ident, $c:expr) => ( impl FromValue for $t { unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult<$t> { Ok(ffi::$f(v)) } unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { sqlite3_value_numeric_type(v) == $c } } ) ); raw_from_impl!(c_int, sqlite3_value_int, ffi::SQLITE_INTEGER); raw_from_impl!(i64, sqlite3_value_int64, ffi::SQLITE_INTEGER); impl FromValue for bool { unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult { match ffi::sqlite3_value_int(v) { 0 => Ok(false), _ => Ok(true), } } unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { sqlite3_value_numeric_type(v) == ffi::SQLITE_INTEGER } } impl FromValue for c_double { unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult { Ok(ffi::sqlite3_value_double(v)) } unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { sqlite3_value_numeric_type(v) == ffi::SQLITE_FLOAT || sqlite3_value_numeric_type(v) == ffi::SQLITE_INTEGER } } impl FromValue for String { unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult { let c_text = ffi::sqlite3_value_text(v); if c_text.is_null() { Ok("".to_string()) } else { let c_slice = CStr::from_ptr(c_text as *const c_char).to_bytes(); let utf8_str = str::from_utf8(c_slice); utf8_str.map(|s| s.to_string()) .map_err(|e| { SqliteError { code: 0, message: e.to_string(), } }) } } unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { sqlite3_value_type(v) == ffi::SQLITE_TEXT } } impl FromValue for Vec { unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult> { use std::slice::from_raw_parts; let c_blob = ffi::sqlite3_value_blob(v); let len = ffi::sqlite3_value_bytes(v); assert!(len >= 0, "unexpected negative return from sqlite3_value_bytes"); let len = len as usize; Ok(from_raw_parts(mem::transmute(c_blob), len).to_vec()) } unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { sqlite3_value_type(v) == ffi::SQLITE_BLOB } } impl FromValue for Option { unsafe fn parameter_value(v: *mut sqlite3_value) -> SqliteResult> { if sqlite3_value_type(v) == ffi::SQLITE_NULL { Ok(None) } else { FromValue::parameter_value(v).map(|t| Some(t)) } } unsafe fn parameter_has_valid_sqlite_type(v: *mut sqlite3_value) -> bool { sqlite3_value_type(v) == ffi::SQLITE_NULL || T::parameter_has_valid_sqlite_type(v) } } unsafe extern "C" fn free_boxed_value(p: *mut c_void) { let _: Box = Box::from_raw(mem::transmute(p)); } pub struct Context<'a> { ctx: *mut sqlite3_context, args: &'a [*mut sqlite3_value], } impl<'a> Context<'a> { pub fn len(&self) -> usize { self.args.len() } pub fn get(&self, idx: usize) -> SqliteResult { let arg = self.args[idx]; unsafe { if T::parameter_has_valid_sqlite_type(arg) { T::parameter_value(arg) } else { Err(SqliteError { code: ffi::SQLITE_MISMATCH, message: "Invalid value type".to_string(), }) } } } pub fn set_aux(&self, arg: c_int, value: T) { let boxed = Box::into_raw(Box::new(value)); unsafe { ffi::sqlite3_set_auxdata(self.ctx, arg, mem::transmute(boxed), Some(mem::transmute(free_boxed_value::))) }; } pub unsafe fn get_aux(&self, arg: c_int) -> Option<&T> { let p = ffi::sqlite3_get_auxdata(self.ctx, arg) as *mut T; if p.is_null() { None } else { Some(&*p) } } } impl SqliteConnection { pub fn create_scalar_function(&self, fn_name: &str, n_arg: c_int, deterministic: bool, x_func: F) -> SqliteResult<()> where F: FnMut(&Context) -> SqliteResult, T: ToResult { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } } impl InnerSqliteConnection { fn create_scalar_function(&mut self, fn_name: &str, n_arg: c_int, deterministic: bool, x_func: F) -> SqliteResult<()> where F: FnMut(&Context) -> SqliteResult, T: ToResult { extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value) where F: FnMut(&Context) -> SqliteResult, T: ToResult { unsafe { let ctx = Context { ctx: ctx, args: slice::from_raw_parts(argv, argc as usize), }; let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx.ctx)); assert!(!boxed_f.is_null(), "Internal error - null function pointer"); match (*boxed_f)(&ctx) { Ok(r) => r.set_result(ctx.ctx), Err(e) => { ffi::sqlite3_result_error_code(ctx.ctx, e.code); if let Ok(cstr) = str_to_cstring(&e.message) { ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); } }, } } } let boxed_f: *mut F = Box::into_raw(Box::new(x_func)); let c_name = try!(str_to_cstring(fn_name)); let mut flags = ffi::SQLITE_UTF8; if deterministic { flags |= ffi::SQLITE_DETERMINISTIC; } let r = unsafe { ffi::sqlite3_create_function_v2(self.db(), c_name.as_ptr(), n_arg, flags, mem::transmute(boxed_f), Some(call_boxed_closure::), None, None, Some(mem::transmute(free_boxed_value::))) }; self.decode_result(r) } } #[cfg(test)] mod test { extern crate regex; use libc::c_double; use self::regex::Regex; use {SqliteConnection, SqliteError, SqliteResult}; use ffi; use functions::Context; fn half(ctx: &Context) -> SqliteResult { assert!(ctx.len() == 1, "called with unexpected number of arguments"); let value = try!(ctx.get::(0)); Ok(value / 2f64) } #[test] fn test_function_half() { let db = SqliteConnection::open_in_memory().unwrap(); db.create_scalar_function("half", 1, true, half).unwrap(); let result = db.query_row("SELECT half(6)", &[], |r| r.get::(0)); assert_eq!(3f64, result.unwrap()); } fn regexp(ctx: &Context) -> SqliteResult { assert!(ctx.len() == 2, "called with unexpected number of arguments"); let saved_re: Option<&Regex> = unsafe { ctx.get_aux(0) }; let new_re = match saved_re { None => { let s = try!(ctx.get::(0)); let r = try!(Regex::new(&s).map_err(|e| SqliteError { code: ffi::SQLITE_ERROR, message: format!("Invalid regular expression: {}", e), })); Some(r) }, Some(_) => None, }; let is_match = { let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap()); let text = try!(ctx.get::(1)); re.is_match(&text) }; if let Some(re) = new_re { ctx.set_aux(0, re); } Ok(is_match) } #[test] #[cfg_attr(rustfmt, rustfmt_skip)] fn test_function_regexp() { let db = SqliteConnection::open_in_memory().unwrap(); db.execute_batch("BEGIN; CREATE TABLE foo (x string); INSERT INTO foo VALUES ('lisa'); INSERT INTO foo VALUES ('lXsi'); INSERT INTO foo VALUES ('lisX'); END;").unwrap(); db.create_scalar_function("regexp", 2, true, regexp).unwrap(); let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", &[], |r| r.get::(0)); assert_eq!(true, result.unwrap()); let result = db.query_row("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1", &[], |r| r.get::(0)); assert_eq!(2, result.unwrap()); } }