diff --git a/.travis.yml b/.travis.yml index f0cc6fe..fa4b2c4 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,12 +28,9 @@ script: - cargo build --features sqlcipher - cargo build --features "bundled sqlcipher" - cargo test - - cargo test --features backup - - cargo test --features blob - - cargo test --features collation - - cargo test --features functions - - cargo test --features hooks - - cargo test --features limits + - cargo test --features "backup blob" + - cargo test --features "collation functions" + - cargo test --features "hooks limits" - cargo test --features load_extension - cargo test --features trace - cargo test --features chrono @@ -43,7 +40,7 @@ script: - cargo test --features sqlcipher - cargo test --features i128_blob - cargo test --features uuid - - cargo test --features "unlock_notify bundled" + - cargo test --features "bundled unlock_notify window" - cargo test --features "array bundled csvtab vtab" - cargo test --features "backup blob chrono collation csvtab functions hooks limits load_extension serde_json trace url uuid vtab" - cargo test --features "backup blob chrono collation csvtab functions hooks limits load_extension serde_json trace url uuid vtab buildtime_bindgen" diff --git a/Cargo.toml b/Cargo.toml index be6cdc1..eefe176 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,8 @@ csvtab = ["csv", "vtab"] array = ["vtab"] # session extension: 3.13.0 session = ["libsqlite3-sys/session", "hooks"] +# window functions: 3.25.0 +window = ["functions"] [dependencies] time = "0.1.0" diff --git a/src/functions.rs b/src/functions.rs index 6a0f54a..06f681d 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -226,6 +226,22 @@ where fn finalize(&self, _: Option) -> Result; } +/// WindowAggregate is the callback interface for user-defined aggregate window +/// function. +#[cfg(feature = "window")] +pub trait WindowAggregate: Aggregate +where + A: RefUnwindSafe + UnwindSafe, + T: ToSql, +{ + /// Returns the current value of the aggregate. Unlike xFinal, the + /// implementation should not delete any context. + fn value(&self, _: Option<&A>) -> Result; + + /// Removes a row from the current window. + fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>; +} + impl Connection { /// Attach a user-defined scalar function to this database connection. /// @@ -294,6 +310,24 @@ impl Connection { .create_aggregate_function(fn_name, n_arg, deterministic, aggr) } + #[cfg(feature = "window")] + pub fn create_window_function( + &self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + aggr: W, + ) -> Result<()> + where + A: RefUnwindSafe + UnwindSafe, + W: WindowAggregate, + T: ToSql, + { + self.db + .borrow_mut() + .create_window_function(fn_name, n_arg, deterministic, aggr) + } + /// Removes a user-defined function from this database connection. /// /// `fn_name` and `n_arg` should match the name and number of arguments @@ -386,105 +420,6 @@ impl InnerConnection { D: Aggregate, T: ToSql, { - unsafe fn aggregate_context( - ctx: *mut sqlite3_context, - bytes: usize, - ) -> Option<*mut *mut A> { - let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A; - if pac.is_null() { - return None; - } - Some(pac) - } - - unsafe extern "C" fn call_boxed_step( - ctx: *mut sqlite3_context, - argc: c_int, - argv: *mut *mut sqlite3_value, - ) where - A: RefUnwindSafe + UnwindSafe, - D: Aggregate, - T: ToSql, - { - let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) { - Some(pac) => pac, - None => { - ffi::sqlite3_result_error_nomem(ctx); - return; - } - }; - - let r = catch_unwind(|| { - let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; - assert!( - !boxed_aggr.is_null(), - "Internal error - null aggregate pointer" - ); - if (*pac as *mut A).is_null() { - *pac = Box::into_raw(Box::new((*boxed_aggr).init())); - } - let mut ctx = Context { - ctx, - args: slice::from_raw_parts(argv, argc as usize), - }; - (*boxed_aggr).step(&mut ctx, &mut **pac) - }); - let r = match r { - Err(_) => { - report_error(ctx, &Error::UnwindingPanic); - return; - } - Ok(r) => r, - }; - match r { - Ok(_) => {} - Err(err) => report_error(ctx, &err), - }; - } - - unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) - where - A: RefUnwindSafe + UnwindSafe, - D: Aggregate, - T: ToSql, - { - // Within the xFinal callback, it is customary to set N=0 in calls to - // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. - let a: Option = match aggregate_context(ctx, 0) { - Some(pac) => { - if (*pac as *mut A).is_null() { - None - } else { - let a = Box::from_raw(*pac); - Some(*a) - } - } - None => None, - }; - - let r = catch_unwind(|| { - let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; - assert!( - !boxed_aggr.is_null(), - "Internal error - null aggregate pointer" - ); - (*boxed_aggr).finalize(a) - }); - let t = match r { - Err(_) => { - report_error(ctx, &Error::UnwindingPanic); - return; - } - Ok(r) => r, - }; - let t = t.as_ref().map(|t| ToSql::to_sql(t)); - match t { - Ok(Ok(ref value)) => set_result(ctx, value), - Ok(Err(err)) => report_error(ctx, &err), - Err(err) => report_error(ctx, err), - } - } - let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); let c_name = str_to_cstring(fn_name)?; let mut flags = ffi::SQLITE_UTF8; @@ -507,6 +442,42 @@ impl InnerConnection { self.decode_result(r) } + #[cfg(feature = "window")] + fn create_window_function( + &mut self, + fn_name: &str, + n_arg: c_int, + deterministic: bool, + aggr: W, + ) -> Result<()> + where + A: RefUnwindSafe + UnwindSafe, + W: WindowAggregate, + T: ToSql, + { + let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr)); + let c_name = str_to_cstring(fn_name)?; + let mut flags = ffi::SQLITE_UTF8; + if deterministic { + flags |= ffi::SQLITE_DETERMINISTIC; + } + let r = unsafe { + ffi::sqlite3_create_window_function( + self.db(), + c_name.as_ptr(), + n_arg, + flags, + boxed_aggr as *mut c_void, + Some(call_boxed_step::), + Some(call_boxed_final::), + Some(call_boxed_value::), + Some(call_boxed_inverse::), + Some(free_boxed_value::), + ) + }; + self.decode_result(r) + } + fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> { let c_name = str_to_cstring(fn_name)?; let r = unsafe { @@ -526,6 +497,189 @@ impl InnerConnection { } } +unsafe fn aggregate_context(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> { + let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A; + if pac.is_null() { + return None; + } + Some(pac) +} + +unsafe extern "C" fn call_boxed_step( + ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, +) where + A: RefUnwindSafe + UnwindSafe, + D: Aggregate, + T: ToSql, +{ + let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) { + Some(pac) => pac, + None => { + ffi::sqlite3_result_error_nomem(ctx); + return; + } + }; + + let r = catch_unwind(|| { + let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + if (*pac as *mut A).is_null() { + *pac = Box::into_raw(Box::new((*boxed_aggr).init())); + } + let mut ctx = Context { + ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + (*boxed_aggr).step(&mut ctx, &mut **pac) + }); + let r = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + match r { + Ok(_) => {} + Err(err) => report_error(ctx, &err), + }; +} + +#[cfg(feature = "window")] +unsafe extern "C" fn call_boxed_inverse( + ctx: *mut sqlite3_context, + argc: c_int, + argv: *mut *mut sqlite3_value, +) where + A: RefUnwindSafe + UnwindSafe, + W: WindowAggregate, + T: ToSql, +{ + let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) { + Some(pac) => pac, + None => { + ffi::sqlite3_result_error_nomem(ctx); + return; + } + }; + + let r = catch_unwind(|| { + let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W; + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + let mut ctx = Context { + ctx, + args: slice::from_raw_parts(argv, argc as usize), + }; + (*boxed_aggr).inverse(&mut ctx, &mut **pac) + }); + let r = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + match r { + Ok(_) => {} + Err(err) => report_error(ctx, &err), + }; +} + +unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) +where + A: RefUnwindSafe + UnwindSafe, + D: Aggregate, + T: ToSql, +{ + // Within the xFinal callback, it is customary to set N=0 in calls to + // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. + let a: Option = match aggregate_context(ctx, 0) { + Some(pac) => { + if (*pac as *mut A).is_null() { + None + } else { + let a = Box::from_raw(*pac); + Some(*a) + } + } + None => None, + }; + + let r = catch_unwind(|| { + let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D; + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + (*boxed_aggr).finalize(a) + }); + let t = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + let t = t.as_ref().map(|t| ToSql::to_sql(t)); + match t { + Ok(Ok(ref value)) => set_result(ctx, value), + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), + } +} + +#[cfg(feature = "window")] +unsafe extern "C" fn call_boxed_value(ctx: *mut sqlite3_context) +where + A: RefUnwindSafe + UnwindSafe, + W: WindowAggregate, + T: ToSql, +{ + // Within the xValue callback, it is customary to set N=0 in calls to + // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur. + let a: Option<&A> = match aggregate_context(ctx, 0) { + Some(pac) => { + if (*pac as *mut A).is_null() { + None + } else { + let a = &**pac; + Some(a) + } + } + None => None, + }; + + let r = catch_unwind(|| { + let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W; + assert!( + !boxed_aggr.is_null(), + "Internal error - null aggregate pointer" + ); + (*boxed_aggr).value(a) + }); + let t = match r { + Err(_) => { + report_error(ctx, &Error::UnwindingPanic); + return; + } + Ok(r) => r, + }; + let t = t.as_ref().map(|t| ToSql::to_sql(t)); + match t { + Ok(Ok(ref value)) => set_result(ctx, value), + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), + } +} + #[cfg(test)] mod test { use regex; @@ -535,6 +689,8 @@ mod test { use std::os::raw::c_double; use crate::functions::{Aggregate, Context}; + #[cfg(feature = "window")] + use crate::functions::WindowAggregate; use crate::{Connection, Error, Result, NO_PARAMS}; fn half(ctx: &Context<'_>) -> Result { @@ -752,4 +908,58 @@ mod test { let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap(); assert_eq!(2, result); } + + #[cfg(feature = "window")] + impl WindowAggregate> for Sum { + fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> { + *sum -= ctx.get::(0)?; + Ok(()) + } + + fn value(&self, sum: Option<&i64>) -> Result> { + Ok(sum.copied()) + } + } + + #[test] + #[cfg(feature = "window")] + fn test_window() { + use fallible_iterator::FallibleIterator; + + let db = Connection::open_in_memory().unwrap(); + db.create_window_function("sumint", 1, true, Sum).unwrap(); + db.execute_batch( + "CREATE TABLE t3(x, y); + INSERT INTO t3 VALUES('a', 4), + ('b', 5), + ('c', 3), + ('d', 8), + ('e', 1);", + ) + .unwrap(); + + let mut stmt = db + .prepare( + "SELECT x, sumint(y) OVER ( + ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING + ) AS sum_y + FROM t3 ORDER BY x;", + ) + .unwrap(); + + let results: Vec<(String, i64)> = stmt + .query(NO_PARAMS) + .unwrap() + .map(|row| Ok((row.get("x")?, row.get("sum_y")?))) + .collect() + .unwrap(); + let expected = vec![ + ("a".to_owned(), 9), + ("b".to_owned(), 12), + ("c".to_owned(), 16), + ("d".to_owned(), 12), + ("e".to_owned(), 9), + ]; + assert_eq!(expected, results); + } }