From 458951e2d575bdfdad951c3ccb5c85a20548b2e4 Mon Sep 17 00:00:00 2001 From: Gwenael Treguier Date: Tue, 15 Dec 2015 20:54:23 +0100 Subject: [PATCH 01/11] First draft to support user defined aggregate functions. --- src/functions.rs | 106 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/src/functions.rs b/src/functions.rs index 196e639..7ed24ac 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -341,6 +341,16 @@ impl<'a> Context<'a> { } } +pub trait Aggregate where T: ToResult { + /// Initializes the aggregation context. + fn init(&self) -> A; // TODO Validate: Fn(&Context) + /// "step" function called once for each row in an aggregate group. + fn step(&self, &mut Context, &mut A); // TODO Validate: Fn(&Context, A) -> A + /// Computes and sets the final result. + fn finalize(&self, &A) -> Result; +} + + impl Connection { /// Attach a user-defined scalar function to this database connection. /// @@ -384,6 +394,20 @@ impl Connection { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } + pub fn create_aggregate_function(&self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + aggr: D) + -> Result<()> + where D: Aggregate, + T: ToResult + { + self.db + .borrow_mut() + .create_aggregate_function(fn_name, n_arg, deterministic, aggr) + } + /// Removes a user-defined function from this database connection. /// /// `fn_name` and `n_arg` should match the name and number of arguments @@ -450,6 +474,88 @@ impl InnerConnection { self.decode_result(r) } + fn create_aggregate_function(&mut self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + aggr: D) + -> Result<()> + where D: Aggregate, + T: ToResult + { + unsafe extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value) + where D: Aggregate, + T: ToResult + { + let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); + assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); + + // TODO Validate: double indirection: `pac` allocated/freed by SQLite and `ac` allocated/freed by Rust. + let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) as *mut *mut A; + if pac.is_null() { + ffi::sqlite3_result_error_nomem(ctx); + return; + } + let ac: *mut A = if (*pac).is_null() { + let a = (*boxed_aggr).init(); + Box::into_raw(Box::new(a)) + } else { + *pac + }; + + let mut ctx = Context { + ctx: ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + + (*boxed_aggr).step(&mut ctx, &mut *ac); + } + unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) + where D: Aggregate, + T: ToResult + { + let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); + assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); + + let pac = ffi::sqlite3_aggregate_context(ctx, 0) as *mut *mut A; + if pac.is_null() || (*pac).is_null() { + return; + } + let ac: *mut A = *pac; + + match (*boxed_aggr).finalize(&*ac) { + Ok(r) => r.set_result(ctx), + Err(e) => { + ffi::sqlite3_result_error_code(ctx, e.code); + if let Ok(cstr) = str_to_cstring(&e.message) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + } + } + + let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); + let c_name = try!(str_to_cstring(fn_name)); + let mut flags = ffi::SQLITE_UTF8; + if deterministic { + flags |= ffi::SQLITE_DETERMINISTIC; + } + let r = unsafe { + ffi::sqlite3_create_function_v2(self.db(), + c_name.as_ptr(), + n_arg, + flags, + mem::transmute(boxed_aggr), + None, + Some(call_boxed_closure::), + Some(call_boxed_final::), + Some(mem::transmute(free_boxed_value::))) + }; + self.decode_result(r) + } + fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> { let c_name = try!(str_to_cstring(fn_name)); let r = unsafe { From 13c93e0f8bed66fd3b96b0d45d46058d1bb83fcc Mon Sep 17 00:00:00 2001 From: Gwenael Treguier Date: Tue, 15 Dec 2015 20:57:32 +0100 Subject: [PATCH 02/11] Rustfmt --- src/functions.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 7ed24ac..9460b2c 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -400,7 +400,7 @@ impl Connection { deterministic: bool, aggr: D) -> Result<()> - where D: Aggregate, + where D: Aggregate, T: ToResult { self.db @@ -480,17 +480,18 @@ impl InnerConnection { deterministic: bool, aggr: D) -> Result<()> - where D: Aggregate, + where D: Aggregate, T: ToResult { unsafe extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, - argc: c_int, - argv: *mut *mut sqlite3_value) - where D: Aggregate, + argc: c_int, + argv: *mut *mut sqlite3_value) + where D: Aggregate, T: ToResult { let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); - assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); + assert!(!boxed_aggr.is_null(), + "Internal error - null aggregate pointer"); // TODO Validate: double indirection: `pac` allocated/freed by SQLite and `ac` allocated/freed by Rust. let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) as *mut *mut A; @@ -513,11 +514,12 @@ impl InnerConnection { (*boxed_aggr).step(&mut ctx, &mut *ac); } unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) - where D: Aggregate, + where D: Aggregate, T: ToResult { let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); - assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); + assert!(!boxed_aggr.is_null(), + "Internal error - null aggregate pointer"); let pac = ffi::sqlite3_aggregate_context(ctx, 0) as *mut *mut A; if pac.is_null() || (*pac).is_null() { From 83b9fd0abaf82459eae137b33fff5a8fca4cca4b Mon Sep 17 00:00:00 2001 From: gwenn Date: Sun, 20 Dec 2015 12:23:51 +0100 Subject: [PATCH 03/11] Test a user-defined aggregate function: my_sum. --- src/functions.rs | 80 ++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 70 insertions(+), 10 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index e2048c2..a36b6fa 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -336,7 +336,7 @@ pub trait Aggregate where T: ToResult { /// Initializes the aggregation context. fn init(&self) -> A; // TODO Validate: Fn(&Context) /// "step" function called once for each row in an aggregate group. - fn step(&self, &mut Context, &mut A); // TODO Validate: Fn(&Context, A) -> A + fn step(&self, &mut Context, &mut A) -> Result<()>; // TODO Validate: Fn(&Context, A) -> A /// Computes and sets the final result. fn finalize(&self, &A) -> Result; } @@ -441,7 +441,7 @@ impl InnerConnection { if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); } - }, + } Err(err) => { ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); if let Ok(cstr) = str_to_cstring(err.description()) { @@ -498,7 +498,8 @@ impl InnerConnection { } let ac: *mut A = if (*pac).is_null() { let a = (*boxed_aggr).init(); - Box::into_raw(Box::new(a)) + *pac = Box::into_raw(Box::new(a)); + *pac } else { *pac }; @@ -508,7 +509,21 @@ impl InnerConnection { args: slice::from_raw_parts(argv, argc as usize), }; - (*boxed_aggr).step(&mut ctx, &mut *ac); + match (*boxed_aggr).step(&mut ctx, &mut *ac) { + Ok(_) => {} + Err(Error::SqliteFailure(err, s)) => { + ffi::sqlite3_result_error_code(ctx.ctx, err.extended_code); + if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { + ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); + } + } + Err(err) => { + ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + if let Ok(cstr) = str_to_cstring(err.description()) { + ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); + } + } + }; } unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) where D: Aggregate, @@ -523,16 +538,23 @@ impl InnerConnection { return; } let ac: *mut A = *pac; + let a = Box::from_raw(mem::transmute(ac)); // to be freed - match (*boxed_aggr).finalize(&*ac) { + match (*boxed_aggr).finalize(&a) { Ok(r) => r.set_result(ctx), - Err(e) => { - ffi::sqlite3_result_error_code(ctx, e.code); - if let Ok(cstr) = str_to_cstring(&e.message) { + Err(Error::SqliteFailure(err, s)) => { + ffi::sqlite3_result_error_code(ctx, err.extended_code); + if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); } } - } + Err(err) => { + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + if let Ok(cstr) = str_to_cstring(err.description()) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + }; } let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); @@ -581,7 +603,7 @@ mod test { use self::regex::Regex; use {Connection, Error, Result}; - use functions::Context; + use functions::{Aggregate, Context}; fn half(ctx: &Context) -> Result { assert!(ctx.len() == 1, "called with unexpected number of arguments"); @@ -739,4 +761,42 @@ mod test { assert_eq!(expected, result); } } + + struct Sum; + + impl Aggregate for Sum { + fn init(&self) -> i64 { + 0 + } + + fn step(&self, ctx: &mut Context, sum: &mut i64) -> Result<()> { + *sum = *sum + try!(ctx.get::(0)); + Ok(()) + } + + fn finalize(&self, sum: &i64) -> Result { + Ok(*sum) + } + } + + #[test] + fn test_sum() { + let db = Connection::open_in_memory().unwrap(); + db.create_aggregate_function("my_sum", 1, true, Sum).unwrap(); + let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; + let result: Option = db.query_row(no_result, &[], |r| r.get(0)) + .unwrap(); + assert!(result.is_none()); + + let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; + let result: i64 = db.query_row(single_sum, &[], |r| r.get(0)) + .unwrap(); + assert_eq!(4, result); + + let single_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL \ + SELECT 2, 1)"; + let result: (i64, i64) = db.query_row(single_sum, &[], |r| (r.get(0), r.get(1))) + .unwrap(); + assert_eq!((4, 2), result); + } } From 987b06cf7942b92768613fc419dbe8ea6ffc9ea6 Mon Sep 17 00:00:00 2001 From: gwenn Date: Sun, 20 Dec 2015 19:27:28 +0100 Subject: [PATCH 04/11] Add some documentation --- src/functions.rs | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index a36b6fa..0772b22 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -332,12 +332,16 @@ impl<'a> Context<'a> { } } +/// Aggregate is the callback interface for user-defined aggregate function. +/// +/// `A` is the type of the aggregation context and `T` is the type of the final result. +/// Implementations should be stateless. pub trait Aggregate where T: ToResult { /// Initializes the aggregation context. - fn init(&self) -> A; // TODO Validate: Fn(&Context) + fn init(&self) -> A; /// "step" function called once for each row in an aggregate group. - fn step(&self, &mut Context, &mut A) -> Result<()>; // TODO Validate: Fn(&Context, A) -> A - /// Computes and sets the final result. + fn step(&self, &mut Context, &mut A) -> Result<()>; + /// Computes and returns the final result. fn finalize(&self, &A) -> Result; } @@ -385,6 +389,11 @@ impl Connection { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } + /// Attach a user-defined aggregate function to this database connection. + /// + /// # Failure + /// + /// Will return Err if the function could not be attached to the connection. pub fn create_aggregate_function(&self, fn_name: &str, n_arg: c_int, @@ -402,7 +411,7 @@ impl Connection { /// Removes a user-defined function from this database connection. /// /// `fn_name` and `n_arg` should match the name and number of arguments - /// given to `create_scalar_function`. + /// given to `create_scalar_function` or `create_aggregate_function`. /// /// # Failure /// @@ -793,9 +802,9 @@ mod test { .unwrap(); assert_eq!(4, result); - let single_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL \ - SELECT 2, 1)"; - let result: (i64, i64) = db.query_row(single_sum, &[], |r| (r.get(0), r.get(1))) + let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \ + 2, 1)"; + let result: (i64, i64) = db.query_row(dual_sum, &[], |r| (r.get(0), r.get(1))) .unwrap(); assert_eq!((4, 2), result); } From 726bd599325ff2be273a4ceab568221d0071f9a6 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Thu, 7 Jan 2016 11:36:01 -0500 Subject: [PATCH 05/11] Fix typo "rowss" in docs. --- src/error.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/error.rs b/src/error.rs index bf128eb..ba93df8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -33,7 +33,7 @@ pub enum Error { /// Error converting a file path to a string. InvalidPath(PathBuf), - /// Error returned when an `execute` call returns rowss. + /// Error returned when an `execute` call returns rows. ExecuteReturnedResults, /// Error when a query that was expected to return at least one row (e.g., for `query_row`) From e4819b6adc994e46b73845adce4ac22a195feaf8 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Thu, 7 Jan 2016 11:42:39 -0500 Subject: [PATCH 06/11] Give Aggregate::finalize ownership of the context it created --- src/functions.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index d856ef6..eb54688 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -342,7 +342,7 @@ pub trait Aggregate where T: ToResult { /// "step" function called once for each row in an aggregate group. fn step(&self, &mut Context, &mut A) -> Result<()>; /// Computes and returns the final result. - fn finalize(&self, &A) -> Result; + fn finalize(&self, A) -> Result; } @@ -549,7 +549,7 @@ impl InnerConnection { let ac: *mut A = *pac; let a = Box::from_raw(mem::transmute(ac)); // to be freed - match (*boxed_aggr).finalize(&a) { + match (*boxed_aggr).finalize(*a) { Ok(r) => r.set_result(ctx), Err(Error::SqliteFailure(err, s)) => { ffi::sqlite3_result_error_code(ctx, err.extended_code); @@ -783,8 +783,8 @@ mod test { Ok(()) } - fn finalize(&self, sum: &i64) -> Result { - Ok(*sum) + fn finalize(&self, sum: i64) -> Result { + Ok(sum) } } From b189f6ba66f009c807d593bf3f725e44ac6555ac Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Thu, 7 Jan 2016 12:33:32 -0500 Subject: [PATCH 07/11] Change how Aggregate works when called on no rows. Before this commit, if the aggregate function was called on 0 rows, it would always return NULL (and never call Aggregate::init() or finalize()). Now, init() and finalize() are always called to get the result of the function, even if step() is never called. --- src/functions.rs | 121 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 83 insertions(+), 38 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index eb54688..c8aee84 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -155,14 +155,6 @@ impl ToResult for Null { } } - -// sqlite3_result_error_code, c_int -// sqlite3_result_error_nomem -// sqlite3_result_error_toobig -// sqlite3_result_error, *const c_char, c_int -// sqlite3_result_zeroblob -// sqlite3_result_value - /// A trait for types that can be created from a SQLite function parameter value. pub trait FromValue: Sized { unsafe fn parameter_value(v: *mut sqlite3_value) -> Result; @@ -337,15 +329,19 @@ impl<'a> Context<'a> { /// `A` is the type of the aggregation context and `T` is the type of the final result. /// Implementations should be stateless. pub trait Aggregate where T: ToResult { - /// Initializes the aggregation context. + /// Initializes the aggregation context. Will be called exactly once for each + /// invocation of the function. fn init(&self) -> A; - /// "step" function called once for each row in an aggregate group. + + /// "step" function called once for each row in an aggregate group. May be called + /// 0 times if there are no rows. fn step(&self, &mut Context, &mut A) -> Result<()>; - /// Computes and returns the final result. + + /// Computes and returns the final result. Will be called exactly once for each + /// invocation of the function. fn finalize(&self, A) -> Result; } - impl Connection { /// Attach a user-defined scalar function to this database connection. /// @@ -489,9 +485,26 @@ impl InnerConnection { where D: Aggregate, T: ToResult { - unsafe extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, - argc: c_int, - argv: *mut *mut sqlite3_value) + // Get our aggregation context from the sqlite3_context. + unsafe fn aggregate_context(agg: &D, ctx: *mut sqlite3_context) -> Result<*mut A> + where D: Aggregate, + T: ToResult + { + let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) + as *mut *mut A; + if pac.is_null() { + return Err(Error::SqliteFailure(ffi::Error::new(ffi::SQLITE_NOMEM), None)); + } + if (*pac).is_null() { + let a = agg.init(); + *pac = Box::into_raw(Box::new(a)); + } + Ok(*pac) + } + + unsafe extern "C" fn call_boxed_step(ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value) where D: Aggregate, T: ToResult { @@ -499,18 +512,12 @@ impl InnerConnection { assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); - // TODO Validate: double indirection: `pac` allocated/freed by SQLite and `ac` allocated/freed by Rust. - let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) as *mut *mut A; - if pac.is_null() { - ffi::sqlite3_result_error_nomem(ctx); - return; - } - let ac: *mut A = if (*pac).is_null() { - let a = (*boxed_aggr).init(); - *pac = Box::into_raw(Box::new(a)); - *pac - } else { - *pac + let agg_ctx = match aggregate_context(&*boxed_aggr, ctx) { + Ok(agg_ctx) => agg_ctx, + Err(_) => { + ffi::sqlite3_result_error_nomem(ctx); + return; + } }; let mut ctx = Context { @@ -518,7 +525,7 @@ impl InnerConnection { args: slice::from_raw_parts(argv, argc as usize), }; - match (*boxed_aggr).step(&mut ctx, &mut *ac) { + match (*boxed_aggr).step(&mut ctx, &mut *agg_ctx) { Ok(_) => {} Err(Error::SqliteFailure(err, s)) => { ffi::sqlite3_result_error_code(ctx.ctx, err.extended_code); @@ -534,6 +541,7 @@ impl InnerConnection { } }; } + unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) where D: Aggregate, T: ToResult @@ -542,12 +550,15 @@ impl InnerConnection { assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); - let pac = ffi::sqlite3_aggregate_context(ctx, 0) as *mut *mut A; - if pac.is_null() || (*pac).is_null() { - return; - } - let ac: *mut A = *pac; - let a = Box::from_raw(mem::transmute(ac)); // to be freed + let agg_ctx = match aggregate_context(&*boxed_aggr, ctx) { + Ok(agg_ctx) => agg_ctx, + Err(_) => { + ffi::sqlite3_result_error_nomem(ctx); + return; + } + }; + + let a = Box::from_raw(agg_ctx); // to be freed match (*boxed_aggr).finalize(*a) { Ok(r) => r.set_result(ctx), @@ -579,7 +590,7 @@ impl InnerConnection { flags, mem::transmute(boxed_aggr), None, - Some(call_boxed_closure::), + Some(call_boxed_step::), Some(call_boxed_final::), Some(mem::transmute(free_boxed_value::))) }; @@ -772,14 +783,30 @@ mod test { } struct Sum; + struct Count; - impl Aggregate for Sum { + impl Aggregate, Option> for Sum { + fn init(&self) -> Option { + None + } + + fn step(&self, ctx: &mut Context, sum: &mut Option) -> Result<()> { + *sum = Some(sum.unwrap_or(0) + try!(ctx.get::(0))); + Ok(()) + } + + fn finalize(&self, sum: Option) -> Result> { + Ok(sum) + } + } + + impl Aggregate for Count { fn init(&self) -> i64 { 0 } - fn step(&self, ctx: &mut Context, sum: &mut i64) -> Result<()> { - *sum = *sum + try!(ctx.get::(0)); + fn step(&self, _ctx: &mut Context, sum: &mut i64) -> Result<()> { + *sum += 1; Ok(()) } @@ -792,6 +819,8 @@ mod test { fn test_sum() { let db = Connection::open_in_memory().unwrap(); db.create_aggregate_function("my_sum", 1, true, Sum).unwrap(); + + // sum should return NULL when given no columns (contrast with count below) let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; let result: Option = db.query_row(no_result, &[], |r| r.get(0)) .unwrap(); @@ -808,4 +837,20 @@ mod test { .unwrap(); assert_eq!((4, 2), result); } + + #[test] + fn test_count() { + let db = Connection::open_in_memory().unwrap(); + db.create_aggregate_function("my_count", -1, true, Count).unwrap(); + + // count should return 0 when given no columns (contrast with sum above) + let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; + let result: i64 = db.query_row(no_result, &[], |r| r.get(0)).unwrap(); + assert_eq!(result, 0); + + let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; + let result: i64 = db.query_row(single_sum, &[], |r| r.get(0)) + .unwrap(); + assert_eq!(2, result); + } } From 199dfc455b6209558642340e7c66218b66caebd8 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Thu, 7 Jan 2016 12:39:16 -0500 Subject: [PATCH 08/11] Internal refactor - extract common error handling code --- src/functions.rs | 45 ++++++++++++++++++++------------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index c8aee84..522f85d 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -502,6 +502,23 @@ impl InnerConnection { Ok(*pac) } + unsafe fn report_aggregate_error(ctx: *mut sqlite3_context, err: Error) { + match err { + Error::SqliteFailure(err, s) => { + ffi::sqlite3_result_error_code(ctx, err.extended_code); + if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + _ => { + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + if let Ok(cstr) = str_to_cstring(err.description()) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + } + } + unsafe extern "C" fn call_boxed_step(ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value) @@ -526,19 +543,8 @@ impl InnerConnection { }; match (*boxed_aggr).step(&mut ctx, &mut *agg_ctx) { - Ok(_) => {} - Err(Error::SqliteFailure(err, s)) => { - ffi::sqlite3_result_error_code(ctx.ctx, err.extended_code); - if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { - ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); - } - } - Err(err) => { - ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); - if let Ok(cstr) = str_to_cstring(err.description()) { - ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); - } - } + Ok(_) => {}, + Err(err) => report_aggregate_error(ctx.ctx, err), }; } @@ -562,18 +568,7 @@ impl InnerConnection { match (*boxed_aggr).finalize(*a) { Ok(r) => r.set_result(ctx), - Err(Error::SqliteFailure(err, s)) => { - ffi::sqlite3_result_error_code(ctx, err.extended_code); - if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { - ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); - } - } - Err(err) => { - ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); - if let Ok(cstr) = str_to_cstring(err.description()) { - ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); - } - } + Err(err) => report_aggregate_error(ctx, err), }; } From 267018b80dbdbc2a31752d019d8fc35a94eb1557 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Thu, 7 Jan 2016 12:40:23 -0500 Subject: [PATCH 09/11] Update Changelog with aggregate functions note --- Changelog.md | 1 + 1 file changed, 1 insertion(+) diff --git a/Changelog.md b/Changelog.md index 9e27d21..b7d2363 100644 --- a/Changelog.md +++ b/Changelog.md @@ -2,6 +2,7 @@ * Adds `column_count()` method to `Statement` and `Row`. * Adds `types::Value` for dynamic column types. +* Adds support for user-defined aggregate functions (behind the existing `functions` Cargo feature). * Introduces a `RowIndex` trait allowing columns to be fetched via index (as before) or name (new). # Version 0.6.0 (2015-12-17) From ca761d76977fa4269d5c35031d4fcde53931f0f8 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Thu, 7 Jan 2016 15:14:24 -0500 Subject: [PATCH 10/11] Avoid creating an aggregation context unnecessarily if the function is called against 0 rows. --- src/functions.rs | 69 ++++++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 522f85d..0b9b51a 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -329,8 +329,9 @@ impl<'a> Context<'a> { /// `A` is the type of the aggregation context and `T` is the type of the final result. /// Implementations should be stateless. pub trait Aggregate where T: ToResult { - /// Initializes the aggregation context. Will be called exactly once for each - /// invocation of the function. + /// Initializes the aggregation context. Will be called prior to the first call + /// to `step()` to set up the context for an invocation of the function. (Note: + /// `init()` will not be called if the there are no rows.) fn init(&self) -> A; /// "step" function called once for each row in an aggregate group. May be called @@ -338,8 +339,11 @@ pub trait Aggregate where T: ToResult { fn step(&self, &mut Context, &mut A) -> Result<()>; /// Computes and returns the final result. Will be called exactly once for each - /// invocation of the function. - fn finalize(&self, A) -> Result; + /// invocation of the function. If `step()` was called at least once, will be given + /// `Some(A)` (the same `A` as was created by `init` and given to `step`); if `step()` + /// was not called (because the function is running against 0 rows), will be given + /// `None`. + fn finalize(&self, Option) -> Result; } impl Connection { @@ -485,21 +489,13 @@ impl InnerConnection { where D: Aggregate, T: ToResult { - // Get our aggregation context from the sqlite3_context. - unsafe fn aggregate_context(agg: &D, ctx: *mut sqlite3_context) -> Result<*mut A> - where D: Aggregate, - T: ToResult - { + unsafe fn aggregate_context(ctx: *mut sqlite3_context) -> Option<*mut *mut A> { let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) as *mut *mut A; if pac.is_null() { - return Err(Error::SqliteFailure(ffi::Error::new(ffi::SQLITE_NOMEM), None)); + return None; } - if (*pac).is_null() { - let a = agg.init(); - *pac = Box::into_raw(Box::new(a)); - } - Ok(*pac) + Some(pac) } unsafe fn report_aggregate_error(ctx: *mut sqlite3_context, err: Error) { @@ -529,21 +525,25 @@ impl InnerConnection { assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); - let agg_ctx = match aggregate_context(&*boxed_aggr, ctx) { - Ok(agg_ctx) => agg_ctx, - Err(_) => { + let pac = match aggregate_context(ctx) { + Some(pac) => pac, + None => { ffi::sqlite3_result_error_nomem(ctx); return; } }; + if (*pac).is_null() { + *pac = Box::into_raw(Box::new((*boxed_aggr).init())); + } + let mut ctx = Context { ctx: ctx, args: slice::from_raw_parts(argv, argc as usize), }; - match (*boxed_aggr).step(&mut ctx, &mut *agg_ctx) { - Ok(_) => {}, + match (*boxed_aggr).step(&mut ctx, &mut **pac) { + Ok(_) => {} Err(err) => report_aggregate_error(ctx.ctx, err), }; } @@ -556,17 +556,22 @@ impl InnerConnection { assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); - let agg_ctx = match aggregate_context(&*boxed_aggr, ctx) { - Ok(agg_ctx) => agg_ctx, - Err(_) => { + let pac = match aggregate_context(ctx) { + Some(pac) => pac, + None => { ffi::sqlite3_result_error_nomem(ctx); return; } }; - let a = Box::from_raw(agg_ctx); // to be freed + let a: Option = if (*pac).is_null() { + None + } else { + let a = Box::from_raw(*pac); + Some(*a) + }; - match (*boxed_aggr).finalize(*a) { + match (*boxed_aggr).finalize(a) { Ok(r) => r.set_result(ctx), Err(err) => report_aggregate_error(ctx, err), }; @@ -780,13 +785,13 @@ mod test { struct Sum; struct Count; - impl Aggregate, Option> for Sum { - fn init(&self) -> Option { - None + impl Aggregate> for Sum { + fn init(&self) -> i64 { + 0 } - fn step(&self, ctx: &mut Context, sum: &mut Option) -> Result<()> { - *sum = Some(sum.unwrap_or(0) + try!(ctx.get::(0))); + fn step(&self, ctx: &mut Context, sum: &mut i64) -> Result<()> { + *sum += try!(ctx.get::(0)); Ok(()) } @@ -805,8 +810,8 @@ mod test { Ok(()) } - fn finalize(&self, sum: i64) -> Result { - Ok(sum) + fn finalize(&self, sum: Option) -> Result { + Ok(sum.unwrap_or(0)) } } From abc5d9e21941a9e6a616e728be33a7c3a02d1860 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Thu, 7 Jan 2016 15:15:43 -0500 Subject: [PATCH 11/11] Test all features on Travis --- .travis.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index ba5124f..89334cf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,8 @@ script: - cargo build - cargo test - cargo test --features backup + - cargo test --features blob - cargo test --features load_extension - cargo test --features trace - cargo test --features functions - - cargo test --features "backup functions load_extension trace" + - cargo test --features "backup blob functions load_extension trace"