From b189f6ba66f009c807d593bf3f725e44ac6555ac Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Thu, 7 Jan 2016 12:33:32 -0500 Subject: [PATCH] Change how Aggregate works when called on no rows. Before this commit, if the aggregate function was called on 0 rows, it would always return NULL (and never call Aggregate::init() or finalize()). Now, init() and finalize() are always called to get the result of the function, even if step() is never called. --- src/functions.rs | 121 ++++++++++++++++++++++++++++++++--------------- 1 file changed, 83 insertions(+), 38 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index eb54688..c8aee84 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -155,14 +155,6 @@ impl ToResult for Null { } } - -// sqlite3_result_error_code, c_int -// sqlite3_result_error_nomem -// sqlite3_result_error_toobig -// sqlite3_result_error, *const c_char, c_int -// sqlite3_result_zeroblob -// sqlite3_result_value - /// A trait for types that can be created from a SQLite function parameter value. pub trait FromValue: Sized { unsafe fn parameter_value(v: *mut sqlite3_value) -> Result; @@ -337,15 +329,19 @@ impl<'a> Context<'a> { /// `A` is the type of the aggregation context and `T` is the type of the final result. /// Implementations should be stateless. pub trait Aggregate where T: ToResult { - /// Initializes the aggregation context. + /// Initializes the aggregation context. Will be called exactly once for each + /// invocation of the function. fn init(&self) -> A; - /// "step" function called once for each row in an aggregate group. + + /// "step" function called once for each row in an aggregate group. May be called + /// 0 times if there are no rows. fn step(&self, &mut Context, &mut A) -> Result<()>; - /// Computes and returns the final result. + + /// Computes and returns the final result. Will be called exactly once for each + /// invocation of the function. fn finalize(&self, A) -> Result; } - impl Connection { /// Attach a user-defined scalar function to this database connection. /// @@ -489,9 +485,26 @@ impl InnerConnection { where D: Aggregate, T: ToResult { - unsafe extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, - argc: c_int, - argv: *mut *mut sqlite3_value) + // Get our aggregation context from the sqlite3_context. + unsafe fn aggregate_context(agg: &D, ctx: *mut sqlite3_context) -> Result<*mut A> + where D: Aggregate, + T: ToResult + { + let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) + as *mut *mut A; + if pac.is_null() { + return Err(Error::SqliteFailure(ffi::Error::new(ffi::SQLITE_NOMEM), None)); + } + if (*pac).is_null() { + let a = agg.init(); + *pac = Box::into_raw(Box::new(a)); + } + Ok(*pac) + } + + unsafe extern "C" fn call_boxed_step(ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value) where D: Aggregate, T: ToResult { @@ -499,18 +512,12 @@ impl InnerConnection { assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); - // TODO Validate: double indirection: `pac` allocated/freed by SQLite and `ac` allocated/freed by Rust. - let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) as *mut *mut A; - if pac.is_null() { - ffi::sqlite3_result_error_nomem(ctx); - return; - } - let ac: *mut A = if (*pac).is_null() { - let a = (*boxed_aggr).init(); - *pac = Box::into_raw(Box::new(a)); - *pac - } else { - *pac + let agg_ctx = match aggregate_context(&*boxed_aggr, ctx) { + Ok(agg_ctx) => agg_ctx, + Err(_) => { + ffi::sqlite3_result_error_nomem(ctx); + return; + } }; let mut ctx = Context { @@ -518,7 +525,7 @@ impl InnerConnection { args: slice::from_raw_parts(argv, argc as usize), }; - match (*boxed_aggr).step(&mut ctx, &mut *ac) { + match (*boxed_aggr).step(&mut ctx, &mut *agg_ctx) { Ok(_) => {} Err(Error::SqliteFailure(err, s)) => { ffi::sqlite3_result_error_code(ctx.ctx, err.extended_code); @@ -534,6 +541,7 @@ impl InnerConnection { } }; } + unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) where D: Aggregate, T: ToResult @@ -542,12 +550,15 @@ impl InnerConnection { assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); - let pac = ffi::sqlite3_aggregate_context(ctx, 0) as *mut *mut A; - if pac.is_null() || (*pac).is_null() { - return; - } - let ac: *mut A = *pac; - let a = Box::from_raw(mem::transmute(ac)); // to be freed + let agg_ctx = match aggregate_context(&*boxed_aggr, ctx) { + Ok(agg_ctx) => agg_ctx, + Err(_) => { + ffi::sqlite3_result_error_nomem(ctx); + return; + } + }; + + let a = Box::from_raw(agg_ctx); // to be freed match (*boxed_aggr).finalize(*a) { Ok(r) => r.set_result(ctx), @@ -579,7 +590,7 @@ impl InnerConnection { flags, mem::transmute(boxed_aggr), None, - Some(call_boxed_closure::), + Some(call_boxed_step::), Some(call_boxed_final::), Some(mem::transmute(free_boxed_value::))) }; @@ -772,14 +783,30 @@ mod test { } struct Sum; + struct Count; - impl Aggregate for Sum { + impl Aggregate, Option> for Sum { + fn init(&self) -> Option { + None + } + + fn step(&self, ctx: &mut Context, sum: &mut Option) -> Result<()> { + *sum = Some(sum.unwrap_or(0) + try!(ctx.get::(0))); + Ok(()) + } + + fn finalize(&self, sum: Option) -> Result> { + Ok(sum) + } + } + + impl Aggregate for Count { fn init(&self) -> i64 { 0 } - fn step(&self, ctx: &mut Context, sum: &mut i64) -> Result<()> { - *sum = *sum + try!(ctx.get::(0)); + fn step(&self, _ctx: &mut Context, sum: &mut i64) -> Result<()> { + *sum += 1; Ok(()) } @@ -792,6 +819,8 @@ mod test { fn test_sum() { let db = Connection::open_in_memory().unwrap(); db.create_aggregate_function("my_sum", 1, true, Sum).unwrap(); + + // sum should return NULL when given no columns (contrast with count below) let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; let result: Option = db.query_row(no_result, &[], |r| r.get(0)) .unwrap(); @@ -808,4 +837,20 @@ mod test { .unwrap(); assert_eq!((4, 2), result); } + + #[test] + fn test_count() { + let db = Connection::open_in_memory().unwrap(); + db.create_aggregate_function("my_count", -1, true, Count).unwrap(); + + // count should return 0 when given no columns (contrast with sum above) + let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; + let result: i64 = db.query_row(no_result, &[], |r| r.get(0)).unwrap(); + assert_eq!(result, 0); + + let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; + let result: i64 = db.query_row(single_sum, &[], |r| r.get(0)) + .unwrap(); + assert_eq!(2, result); + } }