From 458951e2d575bdfdad951c3ccb5c85a20548b2e4 Mon Sep 17 00:00:00 2001 From: Gwenael Treguier Date: Tue, 15 Dec 2015 20:54:23 +0100 Subject: [PATCH 1/4] First draft to support user defined aggregate functions. --- src/functions.rs | 106 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/src/functions.rs b/src/functions.rs index 196e639..7ed24ac 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -341,6 +341,16 @@ impl<'a> Context<'a> { } } +pub trait Aggregate where T: ToResult { + /// Initializes the aggregation context. + fn init(&self) -> A; // TODO Validate: Fn(&Context) + /// "step" function called once for each row in an aggregate group. + fn step(&self, &mut Context, &mut A); // TODO Validate: Fn(&Context, A) -> A + /// Computes and sets the final result. + fn finalize(&self, &A) -> Result; +} + + impl Connection { /// Attach a user-defined scalar function to this database connection. /// @@ -384,6 +394,20 @@ impl Connection { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } + pub fn create_aggregate_function(&self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + aggr: D) + -> Result<()> + where D: Aggregate, + T: ToResult + { + self.db + .borrow_mut() + .create_aggregate_function(fn_name, n_arg, deterministic, aggr) + } + /// Removes a user-defined function from this database connection. /// /// `fn_name` and `n_arg` should match the name and number of arguments @@ -450,6 +474,88 @@ impl InnerConnection { self.decode_result(r) } + fn create_aggregate_function(&mut self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + aggr: D) + -> Result<()> + where D: Aggregate, + T: ToResult + { + unsafe extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value) + where D: Aggregate, + T: ToResult + { + let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); + assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); + + // TODO Validate: double indirection: `pac` allocated/freed by SQLite and `ac` allocated/freed by Rust. + let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) as *mut *mut A; + if pac.is_null() { + ffi::sqlite3_result_error_nomem(ctx); + return; + } + let ac: *mut A = if (*pac).is_null() { + let a = (*boxed_aggr).init(); + Box::into_raw(Box::new(a)) + } else { + *pac + }; + + let mut ctx = Context { + ctx: ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + + (*boxed_aggr).step(&mut ctx, &mut *ac); + } + unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) + where D: Aggregate, + T: ToResult + { + let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); + assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); + + let pac = ffi::sqlite3_aggregate_context(ctx, 0) as *mut *mut A; + if pac.is_null() || (*pac).is_null() { + return; + } + let ac: *mut A = *pac; + + match (*boxed_aggr).finalize(&*ac) { + Ok(r) => r.set_result(ctx), + Err(e) => { + ffi::sqlite3_result_error_code(ctx, e.code); + if let Ok(cstr) = str_to_cstring(&e.message) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + } + } + + let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); + let c_name = try!(str_to_cstring(fn_name)); + let mut flags = ffi::SQLITE_UTF8; + if deterministic { + flags |= ffi::SQLITE_DETERMINISTIC; + } + let r = unsafe { + ffi::sqlite3_create_function_v2(self.db(), + c_name.as_ptr(), + n_arg, + flags, + mem::transmute(boxed_aggr), + None, + Some(call_boxed_closure::), + Some(call_boxed_final::), + Some(mem::transmute(free_boxed_value::))) + }; + self.decode_result(r) + } + fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> { let c_name = try!(str_to_cstring(fn_name)); let r = unsafe { From 13c93e0f8bed66fd3b96b0d45d46058d1bb83fcc Mon Sep 17 00:00:00 2001 From: Gwenael Treguier Date: Tue, 15 Dec 2015 20:57:32 +0100 Subject: [PATCH 2/4] Rustfmt --- src/functions.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 7ed24ac..9460b2c 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -400,7 +400,7 @@ impl Connection { deterministic: bool, aggr: D) -> Result<()> - where D: Aggregate, + where D: Aggregate, T: ToResult { self.db @@ -480,17 +480,18 @@ impl InnerConnection { deterministic: bool, aggr: D) -> Result<()> - where D: Aggregate, + where D: Aggregate, T: ToResult { unsafe extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, - argc: c_int, - argv: *mut *mut sqlite3_value) - where D: Aggregate, + argc: c_int, + argv: *mut *mut sqlite3_value) + where D: Aggregate, T: ToResult { let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); - assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); + assert!(!boxed_aggr.is_null(), + "Internal error - null aggregate pointer"); // TODO Validate: double indirection: `pac` allocated/freed by SQLite and `ac` allocated/freed by Rust. let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) as *mut *mut A; @@ -513,11 +514,12 @@ impl InnerConnection { (*boxed_aggr).step(&mut ctx, &mut *ac); } unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) - where D: Aggregate, + where D: Aggregate, T: ToResult { let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); - assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); + assert!(!boxed_aggr.is_null(), + "Internal error - null aggregate pointer"); let pac = ffi::sqlite3_aggregate_context(ctx, 0) as *mut *mut A; if pac.is_null() || (*pac).is_null() { From 83b9fd0abaf82459eae137b33fff5a8fca4cca4b Mon Sep 17 00:00:00 2001 From: gwenn Date: Sun, 20 Dec 2015 12:23:51 +0100 Subject: [PATCH 3/4] Test a user-defined aggregate function: my_sum. --- src/functions.rs | 80 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 70 insertions(+), 10 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index e2048c2..a36b6fa 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -336,7 +336,7 @@ pub trait Aggregate where T: ToResult { /// Initializes the aggregation context. fn init(&self) -> A; // TODO Validate: Fn(&Context) /// "step" function called once for each row in an aggregate group. - fn step(&self, &mut Context, &mut A); // TODO Validate: Fn(&Context, A) -> A + fn step(&self, &mut Context, &mut A) -> Result<()>; // TODO Validate: Fn(&Context, A) -> A /// Computes and sets the final result. fn finalize(&self, &A) -> Result; } @@ -441,7 +441,7 @@ impl InnerConnection { if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); } - }, + } Err(err) => { ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); if let Ok(cstr) = str_to_cstring(err.description()) { @@ -498,7 +498,8 @@ impl InnerConnection { } let ac: *mut A = if (*pac).is_null() { let a = (*boxed_aggr).init(); - Box::into_raw(Box::new(a)) + *pac = Box::into_raw(Box::new(a)); + *pac } else { *pac }; @@ -508,7 +509,21 @@ impl InnerConnection { args: slice::from_raw_parts(argv, argc as usize), }; - (*boxed_aggr).step(&mut ctx, &mut *ac); + match (*boxed_aggr).step(&mut ctx, &mut *ac) { + Ok(_) => {} + Err(Error::SqliteFailure(err, s)) => { + ffi::sqlite3_result_error_code(ctx.ctx, err.extended_code); + if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { + ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); + } + } + Err(err) => { + ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + if let Ok(cstr) = str_to_cstring(err.description()) { + ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); + } + } + }; } unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) where D: Aggregate, @@ -523,16 +538,23 @@ impl InnerConnection { return; } let ac: *mut A = *pac; + let a = Box::from_raw(mem::transmute(ac)); // to be freed - match (*boxed_aggr).finalize(&*ac) { + match (*boxed_aggr).finalize(&a) { Ok(r) => r.set_result(ctx), - Err(e) => { - ffi::sqlite3_result_error_code(ctx, e.code); - if let Ok(cstr) = str_to_cstring(&e.message) { + Err(Error::SqliteFailure(err, s)) => { + ffi::sqlite3_result_error_code(ctx, err.extended_code); + if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); } } - } + Err(err) => { + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + if let Ok(cstr) = str_to_cstring(err.description()) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + }; } let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); @@ -581,7 +603,7 @@ mod test { use self::regex::Regex; use {Connection, Error, Result}; - use functions::Context; + use functions::{Aggregate, Context}; fn half(ctx: &Context) -> Result { assert!(ctx.len() == 1, "called with unexpected number of arguments"); @@ -739,4 +761,42 @@ mod test { assert_eq!(expected, result); } } + + struct Sum; + + impl Aggregate for Sum { + fn init(&self) -> i64 { + 0 + } + + fn step(&self, ctx: &mut Context, sum: &mut i64) -> Result<()> { + *sum = *sum + try!(ctx.get::(0)); + Ok(()) + } + + fn finalize(&self, sum: &i64) -> Result { + Ok(*sum) + } + } + + #[test] + fn test_sum() { + let db = Connection::open_in_memory().unwrap(); + db.create_aggregate_function("my_sum", 1, true, Sum).unwrap(); + let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; + let result: Option = db.query_row(no_result, &[], |r| r.get(0)) + .unwrap(); + assert!(result.is_none()); + + let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; + let result: i64 = db.query_row(single_sum, &[], |r| r.get(0)) + .unwrap(); + assert_eq!(4, result); + + let single_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL \ + SELECT 2, 1)"; + let result: (i64, i64) = db.query_row(single_sum, &[], |r| (r.get(0), r.get(1))) + .unwrap(); + assert_eq!((4, 2), result); + } } From 987b06cf7942b92768613fc419dbe8ea6ffc9ea6 Mon Sep 17 00:00:00 2001 From: gwenn Date: Sun, 20 Dec 2015 19:27:28 +0100 Subject: [PATCH 4/4] Add some documentation --- src/functions.rs | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index a36b6fa..0772b22 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -332,12 +332,16 @@ impl<'a> Context<'a> { } } +/// Aggregate is the callback interface for user-defined aggregate function. +/// +/// `A` is the type of the aggregation context and `T` is the type of the final result. +/// Implementations should be stateless. pub trait Aggregate where T: ToResult { /// Initializes the aggregation context. - fn init(&self) -> A; // TODO Validate: Fn(&Context) + fn init(&self) -> A; /// "step" function called once for each row in an aggregate group. - fn step(&self, &mut Context, &mut A) -> Result<()>; // TODO Validate: Fn(&Context, A) -> A - /// Computes and sets the final result. + fn step(&self, &mut Context, &mut A) -> Result<()>; + /// Computes and returns the final result. fn finalize(&self, &A) -> Result; } @@ -385,6 +389,11 @@ impl Connection { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } + /// Attach a user-defined aggregate function to this database connection. + /// + /// # Failure + /// + /// Will return Err if the function could not be attached to the connection. pub fn create_aggregate_function(&self, fn_name: &str, n_arg: c_int, @@ -402,7 +411,7 @@ impl Connection { /// Removes a user-defined function from this database connection. /// /// `fn_name` and `n_arg` should match the name and number of arguments - /// given to `create_scalar_function`. + /// given to `create_scalar_function` or `create_aggregate_function`. /// /// # Failure /// @@ -793,9 +802,9 @@ mod test { .unwrap(); assert_eq!(4, result); - let single_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL \ - SELECT 2, 1)"; - let result: (i64, i64) = db.query_row(single_sum, &[], |r| (r.get(0), r.get(1))) + let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \ + 2, 1)"; + let result: (i64, i64) = db.query_row(dual_sum, &[], |r| (r.get(0), r.get(1))) .unwrap(); assert_eq!((4, 2), result); }