From 458951e2d575bdfdad951c3ccb5c85a20548b2e4 Mon Sep 17 00:00:00 2001 From: Gwenael Treguier Date: Tue, 15 Dec 2015 20:54:23 +0100 Subject: [PATCH] First draft to support user defined aggregate functions. --- src/functions.rs | 106 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/src/functions.rs b/src/functions.rs index 196e639..7ed24ac 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -341,6 +341,16 @@ impl<'a> Context<'a> { } } +pub trait Aggregate where T: ToResult { + /// Initializes the aggregation context. + fn init(&self) -> A; // TODO Validate: Fn(&Context) + /// "step" function called once for each row in an aggregate group. + fn step(&self, &mut Context, &mut A); // TODO Validate: Fn(&Context, A) -> A + /// Computes and sets the final result. + fn finalize(&self, &A) -> Result; +} + + impl Connection { /// Attach a user-defined scalar function to this database connection. /// @@ -384,6 +394,20 @@ impl Connection { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } + pub fn create_aggregate_function(&self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + aggr: D) + -> Result<()> + where D: Aggregate, + T: ToResult + { + self.db + .borrow_mut() + .create_aggregate_function(fn_name, n_arg, deterministic, aggr) + } + /// Removes a user-defined function from this database connection. /// /// `fn_name` and `n_arg` should match the name and number of arguments @@ -450,6 +474,88 @@ impl InnerConnection { self.decode_result(r) } + fn create_aggregate_function(&mut self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + aggr: D) + -> Result<()> + where D: Aggregate, + T: ToResult + { + unsafe extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value) + where D: Aggregate, + T: ToResult + { + let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); + assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); + + // TODO Validate: double indirection: `pac` allocated/freed by SQLite and `ac` allocated/freed by Rust. + let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) as *mut *mut A; + if pac.is_null() { + ffi::sqlite3_result_error_nomem(ctx); + return; + } + let ac: *mut A = if (*pac).is_null() { + let a = (*boxed_aggr).init(); + Box::into_raw(Box::new(a)) + } else { + *pac + }; + + let mut ctx = Context { + ctx: ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + + (*boxed_aggr).step(&mut ctx, &mut *ac); + } + unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) + where D: Aggregate, + T: ToResult + { + let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); + assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); + + let pac = ffi::sqlite3_aggregate_context(ctx, 0) as *mut *mut A; + if pac.is_null() || (*pac).is_null() { + return; + } + let ac: *mut A = *pac; + + match (*boxed_aggr).finalize(&*ac) { + Ok(r) => r.set_result(ctx), + Err(e) => { + ffi::sqlite3_result_error_code(ctx, e.code); + if let Ok(cstr) = str_to_cstring(&e.message) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + } + } + + let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); + let c_name = try!(str_to_cstring(fn_name)); + let mut flags = ffi::SQLITE_UTF8; + if deterministic { + flags |= ffi::SQLITE_DETERMINISTIC; + } + let r = unsafe { + ffi::sqlite3_create_function_v2(self.db(), + c_name.as_ptr(), + n_arg, + flags, + mem::transmute(boxed_aggr), + None, + Some(call_boxed_closure::), + Some(call_boxed_final::), + Some(mem::transmute(free_boxed_value::))) + }; + self.decode_result(r) + } + fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> { let c_name = try!(str_to_cstring(fn_name)); let r = unsafe {