mirror of
				https://github.com/isar/rusqlite.git
				synced 2025-10-26 19:38:54 +08:00 
			
		
		
		
	Merge pull request #115 from jgallagher/gwenn-aggregate
Add support for user-defined aggregate functions.
This commit is contained in:
		| @@ -9,7 +9,8 @@ script: | ||||
|     - cargo build | ||||
|     - cargo test | ||||
|     - cargo test --features backup | ||||
|     - cargo test --features blob | ||||
|     - cargo test --features load_extension | ||||
|     - cargo test --features trace | ||||
|     - cargo test --features functions | ||||
|     - cargo test --features "backup functions load_extension trace" | ||||
|     - cargo test --features "backup blob functions load_extension trace" | ||||
|   | ||||
| @@ -2,6 +2,7 @@ | ||||
|  | ||||
| * Adds `column_count()` method to `Statement` and `Row`. | ||||
| * Adds `types::Value` for dynamic column types. | ||||
| * Adds support for user-defined aggregate functions (behind the existing `functions` Cargo feature). | ||||
| * Introduces a `RowIndex` trait allowing columns to be fetched via index (as before) or name (new). | ||||
|  | ||||
| # Version 0.6.0 (2015-12-17) | ||||
|   | ||||
							
								
								
									
										244
									
								
								src/functions.rs
									
									
									
									
									
								
							
							
						
						
									
										244
									
								
								src/functions.rs
									
									
									
									
									
								
							| @@ -155,14 +155,6 @@ impl ToResult for Null { | ||||
|     } | ||||
| } | ||||
|  | ||||
|  | ||||
| // sqlite3_result_error_code, c_int | ||||
| // sqlite3_result_error_nomem | ||||
| // sqlite3_result_error_toobig | ||||
| // sqlite3_result_error, *const c_char, c_int | ||||
| // sqlite3_result_zeroblob | ||||
| // sqlite3_result_value | ||||
|  | ||||
| /// A trait for types that can be created from a SQLite function parameter value. | ||||
| pub trait FromValue: Sized { | ||||
|     unsafe fn parameter_value(v: *mut sqlite3_value) -> Result<Self>; | ||||
| @@ -332,6 +324,28 @@ impl<'a> Context<'a> { | ||||
|     } | ||||
| } | ||||
|  | ||||
| /// Aggregate is the callback interface for user-defined aggregate function. | ||||
| /// | ||||
| /// `A` is the type of the aggregation context and `T` is the type of the final result. | ||||
| /// Implementations should be stateless. | ||||
| pub trait Aggregate<A, T> where T: ToResult { | ||||
|     /// Initializes the aggregation context. Will be called prior to the first call | ||||
|     /// to `step()` to set up the context for an invocation of the function. (Note: | ||||
|     /// `init()` will not be called if the there are no rows.) | ||||
|     fn init(&self) -> A; | ||||
|  | ||||
|     /// "step" function called once for each row in an aggregate group. May be called | ||||
|     /// 0 times if there are no rows. | ||||
|     fn step(&self, &mut Context, &mut A) -> Result<()>; | ||||
|  | ||||
|     /// Computes and returns the final result. Will be called exactly once for each | ||||
|     /// invocation of the function. If `step()` was called at least once, will be given | ||||
|     /// `Some(A)` (the same `A` as was created by `init` and given to `step`); if `step()` | ||||
|     /// was not called (because the function is running against 0 rows), will be given | ||||
|     /// `None`. | ||||
|     fn finalize(&self, Option<A>) -> Result<T>; | ||||
| } | ||||
|  | ||||
| impl Connection { | ||||
|     /// Attach a user-defined scalar function to this database connection. | ||||
|     /// | ||||
| @@ -375,10 +389,29 @@ impl Connection { | ||||
|         self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) | ||||
|     } | ||||
|  | ||||
|     /// Attach a user-defined aggregate function to this database connection. | ||||
|     /// | ||||
|     /// # Failure | ||||
|     /// | ||||
|     /// Will return Err if the function could not be attached to the connection. | ||||
|     pub fn create_aggregate_function<A, D, T>(&self, | ||||
|                                               fn_name: &str, | ||||
|                                               n_arg: c_int, | ||||
|                                               deterministic: bool, | ||||
|                                               aggr: D) | ||||
|                                               -> Result<()> | ||||
|         where D: Aggregate<A, T>, | ||||
|               T: ToResult | ||||
|     { | ||||
|         self.db | ||||
|             .borrow_mut() | ||||
|             .create_aggregate_function(fn_name, n_arg, deterministic, aggr) | ||||
|     } | ||||
|  | ||||
|     /// Removes a user-defined function from this database connection. | ||||
|     /// | ||||
|     /// `fn_name` and `n_arg` should match the name and number of arguments | ||||
|     /// given to `create_scalar_function`. | ||||
|     /// given to `create_scalar_function` or `create_aggregate_function`. | ||||
|     /// | ||||
|     /// # Failure | ||||
|     /// | ||||
| @@ -417,7 +450,7 @@ impl InnerConnection { | ||||
|                     if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { | ||||
|                         ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); | ||||
|                     } | ||||
|                 }, | ||||
|                 } | ||||
|                 Err(err) => { | ||||
|                     ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); | ||||
|                     if let Ok(cstr) = str_to_cstring(err.description()) { | ||||
| @@ -447,6 +480,123 @@ impl InnerConnection { | ||||
|         self.decode_result(r) | ||||
|     } | ||||
|  | ||||
|     fn create_aggregate_function<A, D, T>(&mut self, | ||||
|                                           fn_name: &str, | ||||
|                                           n_arg: c_int, | ||||
|                                           deterministic: bool, | ||||
|                                           aggr: D) | ||||
|                                           -> Result<()> | ||||
|         where D: Aggregate<A, T>, | ||||
|               T: ToResult | ||||
|     { | ||||
|         unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context) -> Option<*mut *mut A> { | ||||
|             let pac = ffi::sqlite3_aggregate_context(ctx, ::std::mem::size_of::<*mut A>() as c_int) | ||||
|                             as *mut *mut A; | ||||
|             if pac.is_null() { | ||||
|                 return None; | ||||
|             } | ||||
|             Some(pac) | ||||
|         } | ||||
|  | ||||
|         unsafe fn report_aggregate_error(ctx: *mut sqlite3_context, err: Error) { | ||||
|             match err { | ||||
|                 Error::SqliteFailure(err, s) => { | ||||
|                     ffi::sqlite3_result_error_code(ctx, err.extended_code); | ||||
|                     if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { | ||||
|                         ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); | ||||
|                     } | ||||
|                 } | ||||
|                 _ => { | ||||
|                     ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); | ||||
|                     if let Ok(cstr) = str_to_cstring(err.description()) { | ||||
|                         ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         unsafe extern "C" fn call_boxed_step<A, D, T>(ctx: *mut sqlite3_context, | ||||
|                                                       argc: c_int, | ||||
|                                                       argv: *mut *mut sqlite3_value) | ||||
|             where D: Aggregate<A, T>, | ||||
|                   T: ToResult | ||||
|         { | ||||
|             let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); | ||||
|             assert!(!boxed_aggr.is_null(), | ||||
|                     "Internal error - null aggregate pointer"); | ||||
|  | ||||
|             let pac = match aggregate_context(ctx) { | ||||
|                 Some(pac) => pac, | ||||
|                 None => { | ||||
|                     ffi::sqlite3_result_error_nomem(ctx); | ||||
|                     return; | ||||
|                 } | ||||
|             }; | ||||
|  | ||||
|             if (*pac).is_null() { | ||||
|                 *pac = Box::into_raw(Box::new((*boxed_aggr).init())); | ||||
|             } | ||||
|  | ||||
|             let mut ctx = Context { | ||||
|                 ctx: ctx, | ||||
|                 args: slice::from_raw_parts(argv, argc as usize), | ||||
|             }; | ||||
|  | ||||
|             match (*boxed_aggr).step(&mut ctx, &mut **pac) { | ||||
|                 Ok(_) => {} | ||||
|                 Err(err) => report_aggregate_error(ctx.ctx, err), | ||||
|             }; | ||||
|         } | ||||
|  | ||||
|         unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context) | ||||
|             where D: Aggregate<A, T>, | ||||
|                   T: ToResult | ||||
|         { | ||||
|             let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); | ||||
|             assert!(!boxed_aggr.is_null(), | ||||
|                     "Internal error - null aggregate pointer"); | ||||
|  | ||||
|             let pac = match aggregate_context(ctx) { | ||||
|                 Some(pac) => pac, | ||||
|                 None => { | ||||
|                     ffi::sqlite3_result_error_nomem(ctx); | ||||
|                     return; | ||||
|                 } | ||||
|             }; | ||||
|  | ||||
|             let a: Option<A> = if (*pac).is_null() { | ||||
|                 None | ||||
|             } else { | ||||
|                 let a = Box::from_raw(*pac); | ||||
|                 Some(*a) | ||||
|             }; | ||||
|  | ||||
|             match (*boxed_aggr).finalize(a) { | ||||
|                 Ok(r) => r.set_result(ctx), | ||||
|                 Err(err) => report_aggregate_error(ctx, err), | ||||
|             }; | ||||
|         } | ||||
|  | ||||
|         let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); | ||||
|         let c_name = try!(str_to_cstring(fn_name)); | ||||
|         let mut flags = ffi::SQLITE_UTF8; | ||||
|         if deterministic { | ||||
|             flags |= ffi::SQLITE_DETERMINISTIC; | ||||
|         } | ||||
|         let r = unsafe { | ||||
|             ffi::sqlite3_create_function_v2(self.db(), | ||||
|                                             c_name.as_ptr(), | ||||
|                                             n_arg, | ||||
|                                             flags, | ||||
|                                             mem::transmute(boxed_aggr), | ||||
|                                             None, | ||||
|                                             Some(call_boxed_step::<A, D, T>), | ||||
|                                             Some(call_boxed_final::<A, D, T>), | ||||
|                                             Some(mem::transmute(free_boxed_value::<D>))) | ||||
|         }; | ||||
|         self.decode_result(r) | ||||
|     } | ||||
|  | ||||
|     fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> { | ||||
|         let c_name = try!(str_to_cstring(fn_name)); | ||||
|         let r = unsafe { | ||||
| @@ -473,7 +623,7 @@ mod test { | ||||
|     use self::regex::Regex; | ||||
|  | ||||
|     use {Connection, Error, Result}; | ||||
|     use functions::Context; | ||||
|     use functions::{Aggregate, Context}; | ||||
|  | ||||
|     fn half(ctx: &Context) -> Result<c_double> { | ||||
|         assert!(ctx.len() == 1, "called with unexpected number of arguments"); | ||||
| @@ -631,4 +781,76 @@ mod test { | ||||
|             assert_eq!(expected, result); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     struct Sum; | ||||
|     struct Count; | ||||
|  | ||||
|     impl Aggregate<i64, Option<i64>> for Sum { | ||||
|         fn init(&self) -> i64 { | ||||
|             0 | ||||
|         } | ||||
|  | ||||
|         fn step(&self, ctx: &mut Context, sum: &mut i64) -> Result<()> { | ||||
|             *sum += try!(ctx.get::<i64>(0)); | ||||
|             Ok(()) | ||||
|         } | ||||
|  | ||||
|         fn finalize(&self, sum: Option<i64>) -> Result<Option<i64>> { | ||||
|             Ok(sum) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     impl Aggregate<i64, i64> for Count { | ||||
|         fn init(&self) -> i64 { | ||||
|             0 | ||||
|         } | ||||
|  | ||||
|         fn step(&self, _ctx: &mut Context, sum: &mut i64) -> Result<()> { | ||||
|             *sum += 1; | ||||
|             Ok(()) | ||||
|         } | ||||
|  | ||||
|         fn finalize(&self, sum: Option<i64>) -> Result<i64> { | ||||
|             Ok(sum.unwrap_or(0)) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_sum() { | ||||
|         let db = Connection::open_in_memory().unwrap(); | ||||
|         db.create_aggregate_function("my_sum", 1, true, Sum).unwrap(); | ||||
|  | ||||
|         // sum should return NULL when given no columns (contrast with count below) | ||||
|         let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; | ||||
|         let result: Option<i64> = db.query_row(no_result, &[], |r| r.get(0)) | ||||
|                                     .unwrap(); | ||||
|         assert!(result.is_none()); | ||||
|  | ||||
|         let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; | ||||
|         let result: i64 = db.query_row(single_sum, &[], |r| r.get(0)) | ||||
|                             .unwrap(); | ||||
|         assert_eq!(4, result); | ||||
|  | ||||
|         let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \ | ||||
|                         2, 1)"; | ||||
|         let result: (i64, i64) = db.query_row(dual_sum, &[], |r| (r.get(0), r.get(1))) | ||||
|                                    .unwrap(); | ||||
|         assert_eq!((4, 2), result); | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_count() { | ||||
|         let db = Connection::open_in_memory().unwrap(); | ||||
|         db.create_aggregate_function("my_count", -1, true, Count).unwrap(); | ||||
|  | ||||
|         // count should return 0 when given no columns (contrast with sum above) | ||||
|         let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; | ||||
|         let result: i64 = db.query_row(no_result, &[], |r| r.get(0)).unwrap(); | ||||
|         assert_eq!(result, 0); | ||||
|  | ||||
|         let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; | ||||
|         let result: i64 = db.query_row(single_sum, &[], |r| r.get(0)) | ||||
|                             .unwrap(); | ||||
|         assert_eq!(2, result); | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user