diff --git a/.travis.yml b/.travis.yml index ba5124f..89334cf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -9,7 +9,8 @@ script: - cargo build - cargo test - cargo test --features backup + - cargo test --features blob - cargo test --features load_extension - cargo test --features trace - cargo test --features functions - - cargo test --features "backup functions load_extension trace" + - cargo test --features "backup blob functions load_extension trace" diff --git a/Changelog.md b/Changelog.md index 9e27d21..b7d2363 100644 --- a/Changelog.md +++ b/Changelog.md @@ -2,6 +2,7 @@ * Adds `column_count()` method to `Statement` and `Row`. * Adds `types::Value` for dynamic column types. +* Adds support for user-defined aggregate functions (behind the existing `functions` Cargo feature). * Introduces a `RowIndex` trait allowing columns to be fetched via index (as before) or name (new). # Version 0.6.0 (2015-12-17) diff --git a/src/functions.rs b/src/functions.rs index d9fecb0..0b9b51a 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -155,14 +155,6 @@ impl ToResult for Null { } } - -// sqlite3_result_error_code, c_int -// sqlite3_result_error_nomem -// sqlite3_result_error_toobig -// sqlite3_result_error, *const c_char, c_int -// sqlite3_result_zeroblob -// sqlite3_result_value - /// A trait for types that can be created from a SQLite function parameter value. pub trait FromValue: Sized { unsafe fn parameter_value(v: *mut sqlite3_value) -> Result; @@ -332,6 +324,28 @@ impl<'a> Context<'a> { } } +/// Aggregate is the callback interface for user-defined aggregate function. +/// +/// `A` is the type of the aggregation context and `T` is the type of the final result. +/// Implementations should be stateless. +pub trait Aggregate where T: ToResult { + /// Initializes the aggregation context. Will be called prior to the first call + /// to `step()` to set up the context for an invocation of the function. (Note: + /// `init()` will not be called if the there are no rows.) + fn init(&self) -> A; + + /// "step" function called once for each row in an aggregate group. May be called + /// 0 times if there are no rows. + fn step(&self, &mut Context, &mut A) -> Result<()>; + + /// Computes and returns the final result. Will be called exactly once for each + /// invocation of the function. If `step()` was called at least once, will be given + /// `Some(A)` (the same `A` as was created by `init` and given to `step`); if `step()` + /// was not called (because the function is running against 0 rows), will be given + /// `None`. + fn finalize(&self, Option) -> Result; +} + impl Connection { /// Attach a user-defined scalar function to this database connection. /// @@ -375,10 +389,29 @@ impl Connection { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } + /// Attach a user-defined aggregate function to this database connection. + /// + /// # Failure + /// + /// Will return Err if the function could not be attached to the connection. + pub fn create_aggregate_function(&self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + aggr: D) + -> Result<()> + where D: Aggregate, + T: ToResult + { + self.db + .borrow_mut() + .create_aggregate_function(fn_name, n_arg, deterministic, aggr) + } + /// Removes a user-defined function from this database connection. /// /// `fn_name` and `n_arg` should match the name and number of arguments - /// given to `create_scalar_function`. + /// given to `create_scalar_function` or `create_aggregate_function`. /// /// # Failure /// @@ -417,7 +450,7 @@ impl InnerConnection { if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); } - }, + } Err(err) => { ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); if let Ok(cstr) = str_to_cstring(err.description()) { @@ -447,6 +480,123 @@ impl InnerConnection { self.decode_result(r) } + fn create_aggregate_function(&mut self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + aggr: D) + -> Result<()> + where D: Aggregate, + T: ToResult + { + unsafe fn aggregate_context(ctx: *mut sqlite3_context) -> Option<*mut *mut A> { + let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) + as *mut *mut A; + if pac.is_null() { + return None; + } + Some(pac) + } + + unsafe fn report_aggregate_error(ctx: *mut sqlite3_context, err: Error) { + match err { + Error::SqliteFailure(err, s) => { + ffi::sqlite3_result_error_code(ctx, err.extended_code); + if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + _ => { + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + if let Ok(cstr) = str_to_cstring(err.description()) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); + } + } + } + } + + unsafe extern "C" fn call_boxed_step(ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value) + where D: Aggregate, + T: ToResult + { + let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); + assert!(!boxed_aggr.is_null(), + "Internal error - null aggregate pointer"); + + let pac = match aggregate_context(ctx) { + Some(pac) => pac, + None => { + ffi::sqlite3_result_error_nomem(ctx); + return; + } + }; + + if (*pac).is_null() { + *pac = Box::into_raw(Box::new((*boxed_aggr).init())); + } + + let mut ctx = Context { + ctx: ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + + match (*boxed_aggr).step(&mut ctx, &mut **pac) { + Ok(_) => {} + Err(err) => report_aggregate_error(ctx.ctx, err), + }; + } + + unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) + where D: Aggregate, + T: ToResult + { + let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); + assert!(!boxed_aggr.is_null(), + "Internal error - null aggregate pointer"); + + let pac = match aggregate_context(ctx) { + Some(pac) => pac, + None => { + ffi::sqlite3_result_error_nomem(ctx); + return; + } + }; + + let a: Option = if (*pac).is_null() { + None + } else { + let a = Box::from_raw(*pac); + Some(*a) + }; + + match (*boxed_aggr).finalize(a) { + Ok(r) => r.set_result(ctx), + Err(err) => report_aggregate_error(ctx, err), + }; + } + + let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); + let c_name = try!(str_to_cstring(fn_name)); + let mut flags = ffi::SQLITE_UTF8; + if deterministic { + flags |= ffi::SQLITE_DETERMINISTIC; + } + let r = unsafe { + ffi::sqlite3_create_function_v2(self.db(), + c_name.as_ptr(), + n_arg, + flags, + mem::transmute(boxed_aggr), + None, + Some(call_boxed_step::), + Some(call_boxed_final::), + Some(mem::transmute(free_boxed_value::))) + }; + self.decode_result(r) + } + fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> { let c_name = try!(str_to_cstring(fn_name)); let r = unsafe { @@ -473,7 +623,7 @@ mod test { use self::regex::Regex; use {Connection, Error, Result}; - use functions::Context; + use functions::{Aggregate, Context}; fn half(ctx: &Context) -> Result { assert!(ctx.len() == 1, "called with unexpected number of arguments"); @@ -631,4 +781,76 @@ mod test { assert_eq!(expected, result); } } + + struct Sum; + struct Count; + + impl Aggregate> for Sum { + fn init(&self) -> i64 { + 0 + } + + fn step(&self, ctx: &mut Context, sum: &mut i64) -> Result<()> { + *sum += try!(ctx.get::(0)); + Ok(()) + } + + fn finalize(&self, sum: Option) -> Result> { + Ok(sum) + } + } + + impl Aggregate for Count { + fn init(&self) -> i64 { + 0 + } + + fn step(&self, _ctx: &mut Context, sum: &mut i64) -> Result<()> { + *sum += 1; + Ok(()) + } + + fn finalize(&self, sum: Option) -> Result { + Ok(sum.unwrap_or(0)) + } + } + + #[test] + fn test_sum() { + let db = Connection::open_in_memory().unwrap(); + db.create_aggregate_function("my_sum", 1, true, Sum).unwrap(); + + // sum should return NULL when given no columns (contrast with count below) + let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; + let result: Option = db.query_row(no_result, &[], |r| r.get(0)) + .unwrap(); + assert!(result.is_none()); + + let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; + let result: i64 = db.query_row(single_sum, &[], |r| r.get(0)) + .unwrap(); + assert_eq!(4, result); + + let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \ + 2, 1)"; + let result: (i64, i64) = db.query_row(dual_sum, &[], |r| (r.get(0), r.get(1))) + .unwrap(); + assert_eq!((4, 2), result); + } + + #[test] + fn test_count() { + let db = Connection::open_in_memory().unwrap(); + db.create_aggregate_function("my_count", -1, true, Count).unwrap(); + + // count should return 0 when given no columns (contrast with sum above) + let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; + let result: i64 = db.query_row(no_result, &[], |r| r.get(0)).unwrap(); + assert_eq!(result, 0); + + let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; + let result: i64 = db.query_row(single_sum, &[], |r| r.get(0)) + .unwrap(); + assert_eq!(2, result); + } }