diff --git a/src/functions.rs b/src/functions.rs index e2048c2..a36b6fa 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -336,7 +336,7 @@ pub trait Aggregate where T: ToResult { /// Initializes the aggregation context. fn init(&self) -> A; // TODO Validate: Fn(&Context) /// "step" function called once for each row in an aggregate group. - fn step(&self, &mut Context, &mut A); // TODO Validate: Fn(&Context, A) -> A + fn step(&self, &mut Context, &mut A) -> Result<()>; // TODO Validate: Fn(&Context, A) -> A /// Computes and sets the final result. fn finalize(&self, &A) -> Result; } @@ -441,7 +441,7 @@ impl InnerConnection { if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); } - }, + } Err(err) => { ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); if let Ok(cstr) = str_to_cstring(err.description()) { @@ -498,7 +498,8 @@ impl InnerConnection { } let ac: *mut A = if (*pac).is_null() { let a = (*boxed_aggr).init(); - Box::into_raw(Box::new(a)) + *pac = Box::into_raw(Box::new(a)); + *pac } else { *pac }; @@ -508,7 +509,21 @@ impl InnerConnection { args: slice::from_raw_parts(argv, argc as usize), }; - (*boxed_aggr).step(&mut ctx, &mut *ac); + match (*boxed_aggr).step(&mut ctx, &mut *ac) { + Ok(_) => {} + Err(Error::SqliteFailure(err, s)) => { + ffi::sqlite3_result_error_code(ctx.ctx, err.extended_code); + if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { + ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); + } + } + Err(err) => { + ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + if let Ok(cstr) = str_to_cstring(err.description()) { + ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); + } + } + }; } unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) where D: Aggregate, @@ -523,16 +538,23 @@ impl InnerConnection { return; } let ac: *mut A = *pac; + let a = Box::from_raw(mem::transmute(ac)); // to be freed - match (*boxed_aggr).finalize(&*ac) { + match (*boxed_aggr).finalize(&a) { Ok(r) => r.set_result(ctx), - Err(e) => { - ffi::sqlite3_result_error_code(ctx, e.code); - if let Ok(cstr) = str_to_cstring(&e.message) { + Err(Error::SqliteFailure(err, s)) => { + ffi::sqlite3_result_error_code(ctx, err.extended_code); + if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); } } - } + Err(err) => { + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + if let Ok(cstr) = str_to_cstring(err.description()) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + }; } let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); @@ -581,7 +603,7 @@ mod test { use self::regex::Regex; use {Connection, Error, Result}; - use functions::Context; + use functions::{Aggregate, Context}; fn half(ctx: &Context) -> Result { assert!(ctx.len() == 1, "called with unexpected number of arguments"); @@ -739,4 +761,42 @@ mod test { assert_eq!(expected, result); } } + + struct Sum; + + impl Aggregate for Sum { + fn init(&self) -> i64 { + 0 + } + + fn step(&self, ctx: &mut Context, sum: &mut i64) -> Result<()> { + *sum = *sum + try!(ctx.get::(0)); + Ok(()) + } + + fn finalize(&self, sum: &i64) -> Result { + Ok(*sum) + } + } + + #[test] + fn test_sum() { + let db = Connection::open_in_memory().unwrap(); + db.create_aggregate_function("my_sum", 1, true, Sum).unwrap(); + let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; + let result: Option = db.query_row(no_result, &[], |r| r.get(0)) + .unwrap(); + assert!(result.is_none()); + + let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; + let result: i64 = db.query_row(single_sum, &[], |r| r.get(0)) + .unwrap(); + assert_eq!(4, result); + + let single_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL \ + SELECT 2, 1)"; + let result: (i64, i64) = db.query_row(single_sum, &[], |r| (r.get(0), r.get(1))) + .unwrap(); + assert_eq!((4, 2), result); + } }