From ca761d76977fa4269d5c35031d4fcde53931f0f8 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Thu, 7 Jan 2016 15:14:24 -0500 Subject: [PATCH] Avoid creating an aggregation context unnecessarily if the function is called against 0 rows. --- src/functions.rs | 69 ++++++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 32 deletions(-) diff --git a/src/functions.rs b/src/functions.rs index 522f85d..0b9b51a 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -329,8 +329,9 @@ impl<'a> Context<'a> { /// `A` is the type of the aggregation context and `T` is the type of the final result. /// Implementations should be stateless. pub trait Aggregate where T: ToResult { - /// Initializes the aggregation context. Will be called exactly once for each - /// invocation of the function. + /// Initializes the aggregation context. Will be called prior to the first call + /// to `step()` to set up the context for an invocation of the function. (Note: + /// `init()` will not be called if the there are no rows.) fn init(&self) -> A; /// "step" function called once for each row in an aggregate group. May be called @@ -338,8 +339,11 @@ pub trait Aggregate where T: ToResult { fn step(&self, &mut Context, &mut A) -> Result<()>; /// Computes and returns the final result. Will be called exactly once for each - /// invocation of the function. - fn finalize(&self, A) -> Result; + /// invocation of the function. If `step()` was called at least once, will be given + /// `Some(A)` (the same `A` as was created by `init` and given to `step`); if `step()` + /// was not called (because the function is running against 0 rows), will be given + /// `None`. + fn finalize(&self, Option) -> Result; } impl Connection { @@ -485,21 +489,13 @@ impl InnerConnection { where D: Aggregate, T: ToResult { - // Get our aggregation context from the sqlite3_context. - unsafe fn aggregate_context(agg: &D, ctx: *mut sqlite3_context) -> Result<*mut A> - where D: Aggregate, - T: ToResult - { + unsafe fn aggregate_context(ctx: *mut sqlite3_context) -> Option<*mut *mut A> { let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) as *mut *mut A; if pac.is_null() { - return Err(Error::SqliteFailure(ffi::Error::new(ffi::SQLITE_NOMEM), None)); + return None; } - if (*pac).is_null() { - let a = agg.init(); - *pac = Box::into_raw(Box::new(a)); - } - Ok(*pac) + Some(pac) } unsafe fn report_aggregate_error(ctx: *mut sqlite3_context, err: Error) { @@ -529,21 +525,25 @@ impl InnerConnection { assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); - let agg_ctx = match aggregate_context(&*boxed_aggr, ctx) { - Ok(agg_ctx) => agg_ctx, - Err(_) => { + let pac = match aggregate_context(ctx) { + Some(pac) => pac, + None => { ffi::sqlite3_result_error_nomem(ctx); return; } }; + if (*pac).is_null() { + *pac = Box::into_raw(Box::new((*boxed_aggr).init())); + } + let mut ctx = Context { ctx: ctx, args: slice::from_raw_parts(argv, argc as usize), }; - match (*boxed_aggr).step(&mut ctx, &mut *agg_ctx) { - Ok(_) => {}, + match (*boxed_aggr).step(&mut ctx, &mut **pac) { + Ok(_) => {} Err(err) => report_aggregate_error(ctx.ctx, err), }; } @@ -556,17 +556,22 @@ impl InnerConnection { assert!(!boxed_aggr.is_null(), "Internal error - null aggregate pointer"); - let agg_ctx = match aggregate_context(&*boxed_aggr, ctx) { - Ok(agg_ctx) => agg_ctx, - Err(_) => { + let pac = match aggregate_context(ctx) { + Some(pac) => pac, + None => { ffi::sqlite3_result_error_nomem(ctx); return; } }; - let a = Box::from_raw(agg_ctx); // to be freed + let a: Option = if (*pac).is_null() { + None + } else { + let a = Box::from_raw(*pac); + Some(*a) + }; - match (*boxed_aggr).finalize(*a) { + match (*boxed_aggr).finalize(a) { Ok(r) => r.set_result(ctx), Err(err) => report_aggregate_error(ctx, err), }; @@ -780,13 +785,13 @@ mod test { struct Sum; struct Count; - impl Aggregate, Option> for Sum { - fn init(&self) -> Option { - None + impl Aggregate> for Sum { + fn init(&self) -> i64 { + 0 } - fn step(&self, ctx: &mut Context, sum: &mut Option) -> Result<()> { - *sum = Some(sum.unwrap_or(0) + try!(ctx.get::(0))); + fn step(&self, ctx: &mut Context, sum: &mut i64) -> Result<()> { + *sum += try!(ctx.get::(0)); Ok(()) } @@ -805,8 +810,8 @@ mod test { Ok(()) } - fn finalize(&self, sum: i64) -> Result { - Ok(sum) + fn finalize(&self, sum: Option) -> Result { + Ok(sum.unwrap_or(0)) } }