Add get to function::Context.

This allows user-defined functions to now only accept a `Context`, as it
embeds the arguments inside itself.
This commit is contained in:
John Gallagher 2015-12-11 14:46:28 -05:00
parent 94d40c41c7
commit 81ec7fe7cd

View File

@ -1,6 +1,7 @@
//! Create or redefine SQL functions
use std::ffi::CStr;
use std::mem;
use std::slice;
use std::str;
use libc::{c_int, c_double, c_char, c_void};
@ -228,11 +229,30 @@ unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
let _: Box<T> = Box::from_raw(mem::transmute(p));
}
pub struct Context {
pub struct Context<'a> {
pub ctx: *mut sqlite3_context,
args: &'a [*mut sqlite3_value],
}
impl Context {
impl<'a> Context<'a> {
pub fn len(&self) -> usize {
self.args.len()
}
pub fn get<T: FromValue>(&self, idx: usize) -> SqliteResult<T> {
let arg = self.args[idx];
unsafe {
if T::parameter_has_valid_sqlite_type(arg) {
T::parameter_value(arg)
} else {
Err(SqliteError {
code: ffi::SQLITE_MISMATCH,
message: "Invalid value type".to_string(),
})
}
}
}
pub fn set_aux<T>(&self, arg: c_int, value: T) {
let boxed = Box::into_raw(Box::new(value));
unsafe {
@ -253,9 +273,6 @@ impl Context {
}
}
pub trait ScalarFunction: FnMut(&Context, c_int, *mut *mut sqlite3_value) {}
impl<F: FnMut(&Context, c_int, *mut *mut sqlite3_value)> ScalarFunction for F {}
impl SqliteConnection {
pub fn create_scalar_function<F>(&self,
fn_name: &str,
@ -263,31 +280,34 @@ impl SqliteConnection {
deterministic: bool,
x_func: F)
-> SqliteResult<()>
where F: ScalarFunction
where F: FnMut(&Context)
{
self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func)
}
}
impl InnerSqliteConnection {
pub fn create_scalar_function<F>(&mut self,
fn_name: &str,
n_arg: c_int,
deterministic: bool,
x_func: F)
-> SqliteResult<()>
where F: ScalarFunction
fn create_scalar_function<F>(&mut self,
fn_name: &str,
n_arg: c_int,
deterministic: bool,
x_func: F)
-> SqliteResult<()>
where F: FnMut(&Context)
{
extern "C" fn call_boxed_closure<F>(ctx: *mut sqlite3_context,
argc: c_int,
argv: *mut *mut sqlite3_value)
where F: ScalarFunction
where F: FnMut(&Context)
{
let ctx = Context { ctx: ctx };
unsafe {
let ctx = Context {
ctx: ctx,
args: slice::from_raw_parts(argv, argc as usize),
};
let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx.ctx));
assert!(!boxed_f.is_null(), "Internal error - null function pointer");
(*boxed_f)(&ctx, argc, argv);
(*boxed_f)(&ctx);
}
}
@ -317,23 +337,18 @@ mod test {
extern crate regex;
use std::ffi::CString;
use libc::{c_int, c_double};
use libc::c_double;
use self::regex::Regex;
use SqliteConnection;
use ffi;
use ffi::sqlite3_value;
use functions::{Context, FromValue, ToResult};
use functions::{Context, ToResult};
fn half(ctx: &Context, _: c_int, argv: *mut *mut sqlite3_value) {
unsafe {
let arg = *argv.offset(0);
if c_double::parameter_has_valid_sqlite_type(arg) {
let value = c_double::parameter_value(arg).unwrap() / 2f64;
value.set_result(ctx.ctx);
} else {
ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_MISMATCH);
}
fn half(ctx: &Context) {
assert!(ctx.len() == 1, "called with unexpected number of arguments");
match ctx.get::<c_double>(0) {
Ok(value) => unsafe { (value / 2f64).set_result(ctx.ctx) },
Err(err) => unsafe { ffi::sqlite3_result_error_code(ctx.ctx, err.code) },
}
}
@ -346,11 +361,13 @@ mod test {
assert_eq!(3f64, result.unwrap());
}
fn regexp(ctx: &Context, _: c_int, argv: *mut *mut sqlite3_value) {
fn regexp(ctx: &Context) {
assert!(ctx.len() == 2, "called with unexpected number of arguments");
let saved_re: Option<&Regex> = unsafe { ctx.get_aux(0) };
let new_re = match saved_re {
None => unsafe {
let raw = String::parameter_value(*argv.offset(0));
let raw = ctx.get::<String>(0);
if raw.is_err() {
let msg = CString::new(format!("{}", raw.unwrap_err())).unwrap();
ffi::sqlite3_result_error(ctx.ctx, msg.as_ptr(), -1);
@ -370,7 +387,7 @@ mod test {
{
let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap());
let text = unsafe { String::parameter_value(*argv.offset(1)) };
let text = ctx.get::<String>(1);
if text.is_ok() {
let text = text.unwrap();
unsafe {
@ -391,13 +408,27 @@ mod test {
}
#[test]
#[cfg_attr(rustfmt, rustfmt_skip)]
fn test_function_regexp() {
let db = SqliteConnection::open_in_memory().unwrap();
db.execute_batch("BEGIN;
CREATE TABLE foo (x string);
INSERT INTO foo VALUES ('lisa');
INSERT INTO foo VALUES ('lXsi');
INSERT INTO foo VALUES ('lisX');
END;").unwrap();
db.create_scalar_function("regexp", 2, true, regexp).unwrap();
let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')",
&[],
|r| r.get::<bool>(0));
assert_eq!(true, result.unwrap());
let result = db.query_row("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
&[],
|r| r.get::<i64>(0));
assert_eq!(2, result.unwrap());
}
}