mirror of
				https://github.com/isar/rusqlite.git
				synced 2025-10-26 11:20:23 +08:00 
			
		
		
		
	Test a user-defined aggregate function: my_sum.
This commit is contained in:
		| @@ -336,7 +336,7 @@ pub trait Aggregate<A, T> where T: ToResult { | ||||
|     /// Initializes the aggregation context. | ||||
|     fn init(&self) -> A; // TODO Validate: Fn(&Context) | ||||
|     /// "step" function called once for each row in an aggregate group. | ||||
|     fn step(&self, &mut Context, &mut A); // TODO Validate: Fn(&Context, A) -> A | ||||
|     fn step(&self, &mut Context, &mut A) -> Result<()>; // TODO Validate: Fn(&Context, A) -> A | ||||
|     /// Computes and sets the final result. | ||||
|     fn finalize(&self, &A) -> Result<T>; | ||||
| } | ||||
| @@ -441,7 +441,7 @@ impl InnerConnection { | ||||
|                     if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { | ||||
|                         ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); | ||||
|                     } | ||||
|                 }, | ||||
|                 } | ||||
|                 Err(err) => { | ||||
|                     ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); | ||||
|                     if let Ok(cstr) = str_to_cstring(err.description()) { | ||||
| @@ -498,7 +498,8 @@ impl InnerConnection { | ||||
|             } | ||||
|             let ac: *mut A = if (*pac).is_null() { | ||||
|                 let a = (*boxed_aggr).init(); | ||||
|                 Box::into_raw(Box::new(a)) | ||||
|                 *pac = Box::into_raw(Box::new(a)); | ||||
|                 *pac | ||||
|             } else { | ||||
|                 *pac | ||||
|             }; | ||||
| @@ -508,7 +509,21 @@ impl InnerConnection { | ||||
|                 args: slice::from_raw_parts(argv, argc as usize), | ||||
|             }; | ||||
|  | ||||
|             (*boxed_aggr).step(&mut ctx, &mut *ac); | ||||
|             match (*boxed_aggr).step(&mut ctx, &mut *ac) { | ||||
|                 Ok(_) => {} | ||||
|                 Err(Error::SqliteFailure(err, s)) => { | ||||
|                     ffi::sqlite3_result_error_code(ctx.ctx, err.extended_code); | ||||
|                     if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { | ||||
|                         ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); | ||||
|                     } | ||||
|                 } | ||||
|                 Err(err) => { | ||||
|                     ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); | ||||
|                     if let Ok(cstr) = str_to_cstring(err.description()) { | ||||
|                         ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); | ||||
|                     } | ||||
|                 } | ||||
|             }; | ||||
|         } | ||||
|         unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context) | ||||
|             where D: Aggregate<A, T>, | ||||
| @@ -523,16 +538,23 @@ impl InnerConnection { | ||||
|                 return; | ||||
|             } | ||||
|             let ac: *mut A = *pac; | ||||
|             let a = Box::from_raw(mem::transmute(ac)); // to be freed | ||||
|  | ||||
|             match (*boxed_aggr).finalize(&*ac) { | ||||
|             match (*boxed_aggr).finalize(&a) { | ||||
|                 Ok(r) => r.set_result(ctx), | ||||
|                 Err(e) => { | ||||
|                     ffi::sqlite3_result_error_code(ctx, e.code); | ||||
|                     if let Ok(cstr) = str_to_cstring(&e.message) { | ||||
|                 Err(Error::SqliteFailure(err, s)) => { | ||||
|                     ffi::sqlite3_result_error_code(ctx, err.extended_code); | ||||
|                     if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { | ||||
|                         ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|                 Err(err) => { | ||||
|                     ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); | ||||
|                     if let Ok(cstr) = str_to_cstring(err.description()) { | ||||
|                         ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); | ||||
|                     } | ||||
|                 } | ||||
|             }; | ||||
|         } | ||||
|  | ||||
|         let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); | ||||
| @@ -581,7 +603,7 @@ mod test { | ||||
|     use self::regex::Regex; | ||||
|  | ||||
|     use {Connection, Error, Result}; | ||||
|     use functions::Context; | ||||
|     use functions::{Aggregate, Context}; | ||||
|  | ||||
|     fn half(ctx: &Context) -> Result<c_double> { | ||||
|         assert!(ctx.len() == 1, "called with unexpected number of arguments"); | ||||
| @@ -739,4 +761,42 @@ mod test { | ||||
|             assert_eq!(expected, result); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     struct Sum; | ||||
|  | ||||
|     impl Aggregate<i64, i64> for Sum { | ||||
|         fn init(&self) -> i64 { | ||||
|             0 | ||||
|         } | ||||
|  | ||||
|         fn step(&self, ctx: &mut Context, sum: &mut i64) -> Result<()> { | ||||
|             *sum = *sum + try!(ctx.get::<i64>(0)); | ||||
|             Ok(()) | ||||
|         } | ||||
|  | ||||
|         fn finalize(&self, sum: &i64) -> Result<i64> { | ||||
|             Ok(*sum) | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     #[test] | ||||
|     fn test_sum() { | ||||
|         let db = Connection::open_in_memory().unwrap(); | ||||
|         db.create_aggregate_function("my_sum", 1, true, Sum).unwrap(); | ||||
|         let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; | ||||
|         let result: Option<i64> = db.query_row(no_result, &[], |r| r.get(0)) | ||||
|                                     .unwrap(); | ||||
|         assert!(result.is_none()); | ||||
|  | ||||
|         let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; | ||||
|         let result: i64 = db.query_row(single_sum, &[], |r| r.get(0)) | ||||
|                             .unwrap(); | ||||
|         assert_eq!(4, result); | ||||
|  | ||||
|         let single_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL \ | ||||
|                           SELECT 2, 1)"; | ||||
|         let result: (i64, i64) = db.query_row(single_sum, &[], |r| (r.get(0), r.get(1))) | ||||
|                                    .unwrap(); | ||||
|         assert_eq!((4, 2), result); | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user