Add binding to sqlite3_create_window_function

This commit is contained in:
gwenn 2019-06-25 20:33:49 +02:00
parent b632c0f8be
commit f1198dd9ff
3 changed files with 314 additions and 107 deletions

View File

@ -28,12 +28,9 @@ script:
- cargo build --features sqlcipher - cargo build --features sqlcipher
- cargo build --features "bundled sqlcipher" - cargo build --features "bundled sqlcipher"
- cargo test - cargo test
- cargo test --features backup - cargo test --features "backup blob"
- cargo test --features blob - cargo test --features "collation functions"
- cargo test --features collation - cargo test --features "hooks limits"
- cargo test --features functions
- cargo test --features hooks
- cargo test --features limits
- cargo test --features load_extension - cargo test --features load_extension
- cargo test --features trace - cargo test --features trace
- cargo test --features chrono - cargo test --features chrono
@ -43,7 +40,7 @@ script:
- cargo test --features sqlcipher - cargo test --features sqlcipher
- cargo test --features i128_blob - cargo test --features i128_blob
- cargo test --features uuid - cargo test --features uuid
- cargo test --features "unlock_notify bundled" - cargo test --features "bundled session unlock_notify window"
- cargo test --features "array bundled csvtab vtab" - cargo test --features "array bundled csvtab vtab"
- cargo test --features "backup blob chrono collation csvtab functions hooks limits load_extension serde_json trace url uuid vtab" - cargo test --features "backup blob chrono collation csvtab functions hooks limits load_extension serde_json trace url uuid vtab"
- cargo test --features "backup blob chrono collation csvtab functions hooks limits load_extension serde_json trace url uuid vtab buildtime_bindgen" - cargo test --features "backup blob chrono collation csvtab functions hooks limits load_extension serde_json trace url uuid vtab buildtime_bindgen"

View File

@ -49,6 +49,8 @@ csvtab = ["csv", "vtab"]
array = ["vtab"] array = ["vtab"]
# session extension: 3.13.0 # session extension: 3.13.0
session = ["libsqlite3-sys/session", "hooks"] session = ["libsqlite3-sys/session", "hooks"]
# window functions: 3.25.0
window = ["functions"]
[dependencies] [dependencies]
time = "0.1.0" time = "0.1.0"

View File

@ -226,6 +226,22 @@ where
fn finalize(&self, _: Option<A>) -> Result<T>; fn finalize(&self, _: Option<A>) -> Result<T>;
} }
/// WindowAggregate is the callback interface for user-defined aggregate window
/// function.
#[cfg(feature = "window")]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
A: RefUnwindSafe + UnwindSafe,
T: ToSql,
{
/// Returns the current value of the aggregate. Unlike xFinal, the
/// implementation should not delete any context.
fn value(&self, _: Option<&A>) -> Result<T>;
/// Removes a row from the current window.
fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
}
impl Connection { impl Connection {
/// Attach a user-defined scalar function to this database connection. /// Attach a user-defined scalar function to this database connection.
/// ///
@ -294,6 +310,24 @@ impl Connection {
.create_aggregate_function(fn_name, n_arg, deterministic, aggr) .create_aggregate_function(fn_name, n_arg, deterministic, aggr)
} }
#[cfg(feature = "window")]
pub fn create_window_function<A, W, T>(
&self,
fn_name: &str,
n_arg: c_int,
deterministic: bool,
aggr: W,
) -> Result<()>
where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T>,
T: ToSql,
{
self.db
.borrow_mut()
.create_window_function(fn_name, n_arg, deterministic, aggr)
}
/// Removes a user-defined function from this database connection. /// Removes a user-defined function from this database connection.
/// ///
/// `fn_name` and `n_arg` should match the name and number of arguments /// `fn_name` and `n_arg` should match the name and number of arguments
@ -386,10 +420,84 @@ impl InnerConnection {
D: Aggregate<A, T>, D: Aggregate<A, T>,
T: ToSql, T: ToSql,
{ {
unsafe fn aggregate_context<A>( let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
ctx: *mut sqlite3_context, let c_name = str_to_cstring(fn_name)?;
bytes: usize, let mut flags = ffi::SQLITE_UTF8;
) -> Option<*mut *mut A> { if deterministic {
flags |= ffi::SQLITE_DETERMINISTIC;
}
let r = unsafe {
ffi::sqlite3_create_function_v2(
self.db(),
c_name.as_ptr(),
n_arg,
flags,
boxed_aggr as *mut c_void,
None,
Some(call_boxed_step::<A, D, T>),
Some(call_boxed_final::<A, D, T>),
Some(free_boxed_value::<D>),
)
};
self.decode_result(r)
}
#[cfg(feature = "window")]
fn create_window_function<A, W, T>(
&mut self,
fn_name: &str,
n_arg: c_int,
deterministic: bool,
aggr: W,
) -> Result<()>
where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T>,
T: ToSql,
{
let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
let c_name = str_to_cstring(fn_name)?;
let mut flags = ffi::SQLITE_UTF8;
if deterministic {
flags |= ffi::SQLITE_DETERMINISTIC;
}
let r = unsafe {
ffi::sqlite3_create_window_function(
self.db(),
c_name.as_ptr(),
n_arg,
flags,
boxed_aggr as *mut c_void,
Some(call_boxed_step::<A, W, T>),
Some(call_boxed_final::<A, W, T>),
Some(call_boxed_value::<A, W, T>),
Some(call_boxed_inverse::<A, W, T>),
Some(free_boxed_value::<W>),
)
};
self.decode_result(r)
}
fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
let c_name = str_to_cstring(fn_name)?;
let r = unsafe {
ffi::sqlite3_create_function_v2(
self.db(),
c_name.as_ptr(),
n_arg,
ffi::SQLITE_UTF8,
ptr::null_mut(),
None,
None,
None,
None,
)
};
self.decode_result(r)
}
}
unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A; let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
if pac.is_null() { if pac.is_null() {
return None; return None;
@ -442,6 +550,49 @@ impl InnerConnection {
}; };
} }
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
ctx: *mut sqlite3_context,
argc: c_int,
argv: *mut *mut sqlite3_value,
) where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T>,
T: ToSql,
{
let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
Some(pac) => pac,
None => {
ffi::sqlite3_result_error_nomem(ctx);
return;
}
};
let r = catch_unwind(|| {
let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
);
let mut ctx = Context {
ctx,
args: slice::from_raw_parts(argv, argc as usize),
};
(*boxed_aggr).inverse(&mut ctx, &mut **pac)
});
let r = match r {
Err(_) => {
report_error(ctx, &Error::UnwindingPanic);
return;
}
Ok(r) => r,
};
match r {
Ok(_) => {}
Err(err) => report_error(ctx, &err),
};
}
unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context) unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
where where
A: RefUnwindSafe + UnwindSafe, A: RefUnwindSafe + UnwindSafe,
@ -485,44 +636,47 @@ impl InnerConnection {
} }
} }
let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr)); #[cfg(feature = "window")]
let c_name = str_to_cstring(fn_name)?; unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
let mut flags = ffi::SQLITE_UTF8; where
if deterministic { A: RefUnwindSafe + UnwindSafe,
flags |= ffi::SQLITE_DETERMINISTIC; W: WindowAggregate<A, T>,
T: ToSql,
{
// Within the xValue callback, it is customary to set N=0 in calls to
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
let a: Option<&A> = match aggregate_context(ctx, 0) {
Some(pac) => {
if (*pac as *mut A).is_null() {
None
} else {
let a = &**pac;
Some(a)
} }
let r = unsafe { }
ffi::sqlite3_create_function_v2( None => None,
self.db(),
c_name.as_ptr(),
n_arg,
flags,
boxed_aggr as *mut c_void,
None,
Some(call_boxed_step::<A, D, T>),
Some(call_boxed_final::<A, D, T>),
Some(free_boxed_value::<D>),
)
}; };
self.decode_result(r)
}
fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> { let r = catch_unwind(|| {
let c_name = str_to_cstring(fn_name)?; let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
let r = unsafe { assert!(
ffi::sqlite3_create_function_v2( !boxed_aggr.is_null(),
self.db(), "Internal error - null aggregate pointer"
c_name.as_ptr(), );
n_arg, (*boxed_aggr).value(a)
ffi::SQLITE_UTF8, });
ptr::null_mut(), let t = match r {
None, Err(_) => {
None, report_error(ctx, &Error::UnwindingPanic);
None, return;
None, }
) Ok(r) => r,
}; };
self.decode_result(r) let t = t.as_ref().map(|t| ToSql::to_sql(t));
match t {
Ok(Ok(ref value)) => set_result(ctx, value),
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
} }
} }
@ -534,7 +688,7 @@ mod test {
use std::f64::EPSILON; use std::f64::EPSILON;
use std::os::raw::c_double; use std::os::raw::c_double;
use crate::functions::{Aggregate, Context}; use crate::functions::{Aggregate, Context, WindowAggregate};
use crate::{Connection, Error, Result, NO_PARAMS}; use crate::{Connection, Error, Result, NO_PARAMS};
fn half(ctx: &Context<'_>) -> Result<c_double> { fn half(ctx: &Context<'_>) -> Result<c_double> {
@ -752,4 +906,58 @@ mod test {
let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap(); let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
assert_eq!(2, result); assert_eq!(2, result);
} }
#[cfg(feature = "window")]
impl WindowAggregate<i64, Option<i64>> for Sum {
fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
*sum -= ctx.get::<i64>(0)?;
Ok(())
}
fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
Ok(sum.copied())
}
}
#[test]
#[cfg(feature = "window")]
fn test_window() {
use fallible_iterator::FallibleIterator;
let db = Connection::open_in_memory().unwrap();
db.create_window_function("sumint", 1, true, Sum).unwrap();
db.execute_batch(
"CREATE TABLE t3(x, y);
INSERT INTO t3 VALUES('a', 4),
('b', 5),
('c', 3),
('d', 8),
('e', 1);",
)
.unwrap();
let mut stmt = db
.prepare(
"SELECT x, sumint(y) OVER (
ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
) AS sum_y
FROM t3 ORDER BY x;",
)
.unwrap();
let results: Vec<(String, i64)> = stmt
.query(NO_PARAMS)
.unwrap()
.map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
.collect()
.unwrap();
let expected = vec![
("a".to_owned(), 9),
("b".to_owned(), 12),
("c".to_owned(), 16),
("d".to_owned(), 12),
("e".to_owned(), 9),
];
assert_eq!(expected, results);
}
} }