mirror of
https://github.com/isar/rusqlite.git
synced 2024-11-23 00:39:20 +08:00
Merge pull request #539 from gwenn/window-func
Add binding to `sqlite3_create_window_function`
This commit is contained in:
commit
ef036e383c
11
.travis.yml
11
.travis.yml
@ -28,12 +28,9 @@ script:
|
||||
- cargo build --features sqlcipher
|
||||
- cargo build --features "bundled sqlcipher"
|
||||
- cargo test
|
||||
- cargo test --features backup
|
||||
- cargo test --features blob
|
||||
- cargo test --features collation
|
||||
- cargo test --features functions
|
||||
- cargo test --features hooks
|
||||
- cargo test --features limits
|
||||
- cargo test --features "backup blob"
|
||||
- cargo test --features "collation functions"
|
||||
- cargo test --features "hooks limits"
|
||||
- cargo test --features load_extension
|
||||
- cargo test --features trace
|
||||
- cargo test --features chrono
|
||||
@ -43,7 +40,7 @@ script:
|
||||
- cargo test --features sqlcipher
|
||||
- cargo test --features i128_blob
|
||||
- cargo test --features uuid
|
||||
- cargo test --features "unlock_notify bundled"
|
||||
- cargo test --features "bundled unlock_notify window"
|
||||
- cargo test --features "array bundled csvtab vtab"
|
||||
- cargo test --features "backup blob chrono collation csvtab functions hooks limits load_extension serde_json trace url uuid vtab"
|
||||
- cargo test --features "backup blob chrono collation csvtab functions hooks limits load_extension serde_json trace url uuid vtab buildtime_bindgen"
|
||||
|
@ -49,6 +49,8 @@ csvtab = ["csv", "vtab"]
|
||||
array = ["vtab"]
|
||||
# session extension: 3.13.0
|
||||
session = ["libsqlite3-sys/session", "hooks"]
|
||||
# window functions: 3.25.0
|
||||
window = ["functions"]
|
||||
|
||||
[dependencies]
|
||||
time = "0.1.0"
|
||||
|
304
src/functions.rs
304
src/functions.rs
@ -226,6 +226,22 @@ where
|
||||
fn finalize(&self, _: Option<A>) -> Result<T>;
|
||||
}
|
||||
|
||||
/// WindowAggregate is the callback interface for user-defined aggregate window
|
||||
/// function.
|
||||
#[cfg(feature = "window")]
|
||||
pub trait WindowAggregate<A, T>: Aggregate<A, T>
|
||||
where
|
||||
A: RefUnwindSafe + UnwindSafe,
|
||||
T: ToSql,
|
||||
{
|
||||
/// Returns the current value of the aggregate. Unlike xFinal, the
|
||||
/// implementation should not delete any context.
|
||||
fn value(&self, _: Option<&A>) -> Result<T>;
|
||||
|
||||
/// Removes a row from the current window.
|
||||
fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
/// Attach a user-defined scalar function to this database connection.
|
||||
///
|
||||
@ -294,6 +310,24 @@ impl Connection {
|
||||
.create_aggregate_function(fn_name, n_arg, deterministic, aggr)
|
||||
}
|
||||
|
||||
#[cfg(feature = "window")]
|
||||
pub fn create_window_function<A, W, T>(
|
||||
&self,
|
||||
fn_name: &str,
|
||||
n_arg: c_int,
|
||||
deterministic: bool,
|
||||
aggr: W,
|
||||
) -> Result<()>
|
||||
where
|
||||
A: RefUnwindSafe + UnwindSafe,
|
||||
W: WindowAggregate<A, T>,
|
||||
T: ToSql,
|
||||
{
|
||||
self.db
|
||||
.borrow_mut()
|
||||
.create_window_function(fn_name, n_arg, deterministic, aggr)
|
||||
}
|
||||
|
||||
/// Removes a user-defined function from this database connection.
|
||||
///
|
||||
/// `fn_name` and `n_arg` should match the name and number of arguments
|
||||
@ -386,26 +420,100 @@ impl InnerConnection {
|
||||
D: Aggregate<A, T>,
|
||||
T: ToSql,
|
||||
{
|
||||
unsafe fn aggregate_context<A>(
|
||||
ctx: *mut sqlite3_context,
|
||||
bytes: usize,
|
||||
) -> Option<*mut *mut A> {
|
||||
let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
|
||||
let c_name = str_to_cstring(fn_name)?;
|
||||
let mut flags = ffi::SQLITE_UTF8;
|
||||
if deterministic {
|
||||
flags |= ffi::SQLITE_DETERMINISTIC;
|
||||
}
|
||||
let r = unsafe {
|
||||
ffi::sqlite3_create_function_v2(
|
||||
self.db(),
|
||||
c_name.as_ptr(),
|
||||
n_arg,
|
||||
flags,
|
||||
boxed_aggr as *mut c_void,
|
||||
None,
|
||||
Some(call_boxed_step::<A, D, T>),
|
||||
Some(call_boxed_final::<A, D, T>),
|
||||
Some(free_boxed_value::<D>),
|
||||
)
|
||||
};
|
||||
self.decode_result(r)
|
||||
}
|
||||
|
||||
#[cfg(feature = "window")]
|
||||
fn create_window_function<A, W, T>(
|
||||
&mut self,
|
||||
fn_name: &str,
|
||||
n_arg: c_int,
|
||||
deterministic: bool,
|
||||
aggr: W,
|
||||
) -> Result<()>
|
||||
where
|
||||
A: RefUnwindSafe + UnwindSafe,
|
||||
W: WindowAggregate<A, T>,
|
||||
T: ToSql,
|
||||
{
|
||||
let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
|
||||
let c_name = str_to_cstring(fn_name)?;
|
||||
let mut flags = ffi::SQLITE_UTF8;
|
||||
if deterministic {
|
||||
flags |= ffi::SQLITE_DETERMINISTIC;
|
||||
}
|
||||
let r = unsafe {
|
||||
ffi::sqlite3_create_window_function(
|
||||
self.db(),
|
||||
c_name.as_ptr(),
|
||||
n_arg,
|
||||
flags,
|
||||
boxed_aggr as *mut c_void,
|
||||
Some(call_boxed_step::<A, W, T>),
|
||||
Some(call_boxed_final::<A, W, T>),
|
||||
Some(call_boxed_value::<A, W, T>),
|
||||
Some(call_boxed_inverse::<A, W, T>),
|
||||
Some(free_boxed_value::<W>),
|
||||
)
|
||||
};
|
||||
self.decode_result(r)
|
||||
}
|
||||
|
||||
fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
|
||||
let c_name = str_to_cstring(fn_name)?;
|
||||
let r = unsafe {
|
||||
ffi::sqlite3_create_function_v2(
|
||||
self.db(),
|
||||
c_name.as_ptr(),
|
||||
n_arg,
|
||||
ffi::SQLITE_UTF8,
|
||||
ptr::null_mut(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
};
|
||||
self.decode_result(r)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
|
||||
let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
|
||||
if pac.is_null() {
|
||||
return None;
|
||||
}
|
||||
Some(pac)
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "C" fn call_boxed_step<A, D, T>(
|
||||
unsafe extern "C" fn call_boxed_step<A, D, T>(
|
||||
ctx: *mut sqlite3_context,
|
||||
argc: c_int,
|
||||
argv: *mut *mut sqlite3_value,
|
||||
) where
|
||||
) where
|
||||
A: RefUnwindSafe + UnwindSafe,
|
||||
D: Aggregate<A, T>,
|
||||
T: ToSql,
|
||||
{
|
||||
{
|
||||
let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
|
||||
Some(pac) => pac,
|
||||
None => {
|
||||
@ -440,14 +548,57 @@ impl InnerConnection {
|
||||
Ok(_) => {}
|
||||
Err(err) => report_error(ctx, &err),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
|
||||
where
|
||||
#[cfg(feature = "window")]
|
||||
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
|
||||
ctx: *mut sqlite3_context,
|
||||
argc: c_int,
|
||||
argv: *mut *mut sqlite3_value,
|
||||
) where
|
||||
A: RefUnwindSafe + UnwindSafe,
|
||||
W: WindowAggregate<A, T>,
|
||||
T: ToSql,
|
||||
{
|
||||
let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
|
||||
Some(pac) => pac,
|
||||
None => {
|
||||
ffi::sqlite3_result_error_nomem(ctx);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let r = catch_unwind(|| {
|
||||
let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
|
||||
assert!(
|
||||
!boxed_aggr.is_null(),
|
||||
"Internal error - null aggregate pointer"
|
||||
);
|
||||
let mut ctx = Context {
|
||||
ctx,
|
||||
args: slice::from_raw_parts(argv, argc as usize),
|
||||
};
|
||||
(*boxed_aggr).inverse(&mut ctx, &mut **pac)
|
||||
});
|
||||
let r = match r {
|
||||
Err(_) => {
|
||||
report_error(ctx, &Error::UnwindingPanic);
|
||||
return;
|
||||
}
|
||||
Ok(r) => r,
|
||||
};
|
||||
match r {
|
||||
Ok(_) => {}
|
||||
Err(err) => report_error(ctx, &err),
|
||||
};
|
||||
}
|
||||
|
||||
unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
|
||||
where
|
||||
A: RefUnwindSafe + UnwindSafe,
|
||||
D: Aggregate<A, T>,
|
||||
T: ToSql,
|
||||
{
|
||||
{
|
||||
// Within the xFinal callback, it is customary to set N=0 in calls to
|
||||
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
|
||||
let a: Option<A> = match aggregate_context(ctx, 0) {
|
||||
@ -483,46 +634,49 @@ impl InnerConnection {
|
||||
Ok(Err(err)) => report_error(ctx, &err),
|
||||
Err(err) => report_error(ctx, err),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
|
||||
let c_name = str_to_cstring(fn_name)?;
|
||||
let mut flags = ffi::SQLITE_UTF8;
|
||||
if deterministic {
|
||||
flags |= ffi::SQLITE_DETERMINISTIC;
|
||||
#[cfg(feature = "window")]
|
||||
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
|
||||
where
|
||||
A: RefUnwindSafe + UnwindSafe,
|
||||
W: WindowAggregate<A, T>,
|
||||
T: ToSql,
|
||||
{
|
||||
// Within the xValue callback, it is customary to set N=0 in calls to
|
||||
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
|
||||
let a: Option<&A> = match aggregate_context(ctx, 0) {
|
||||
Some(pac) => {
|
||||
if (*pac as *mut A).is_null() {
|
||||
None
|
||||
} else {
|
||||
let a = &**pac;
|
||||
Some(a)
|
||||
}
|
||||
let r = unsafe {
|
||||
ffi::sqlite3_create_function_v2(
|
||||
self.db(),
|
||||
c_name.as_ptr(),
|
||||
n_arg,
|
||||
flags,
|
||||
boxed_aggr as *mut c_void,
|
||||
None,
|
||||
Some(call_boxed_step::<A, D, T>),
|
||||
Some(call_boxed_final::<A, D, T>),
|
||||
Some(free_boxed_value::<D>),
|
||||
)
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
self.decode_result(r)
|
||||
}
|
||||
|
||||
fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
|
||||
let c_name = str_to_cstring(fn_name)?;
|
||||
let r = unsafe {
|
||||
ffi::sqlite3_create_function_v2(
|
||||
self.db(),
|
||||
c_name.as_ptr(),
|
||||
n_arg,
|
||||
ffi::SQLITE_UTF8,
|
||||
ptr::null_mut(),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
let r = catch_unwind(|| {
|
||||
let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
|
||||
assert!(
|
||||
!boxed_aggr.is_null(),
|
||||
"Internal error - null aggregate pointer"
|
||||
);
|
||||
(*boxed_aggr).value(a)
|
||||
});
|
||||
let t = match r {
|
||||
Err(_) => {
|
||||
report_error(ctx, &Error::UnwindingPanic);
|
||||
return;
|
||||
}
|
||||
Ok(r) => r,
|
||||
};
|
||||
self.decode_result(r)
|
||||
let t = t.as_ref().map(|t| ToSql::to_sql(t));
|
||||
match t {
|
||||
Ok(Ok(ref value)) => set_result(ctx, value),
|
||||
Ok(Err(err)) => report_error(ctx, &err),
|
||||
Err(err) => report_error(ctx, err),
|
||||
}
|
||||
}
|
||||
|
||||
@ -535,6 +689,8 @@ mod test {
|
||||
use std::os::raw::c_double;
|
||||
|
||||
use crate::functions::{Aggregate, Context};
|
||||
#[cfg(feature = "window")]
|
||||
use crate::functions::WindowAggregate;
|
||||
use crate::{Connection, Error, Result, NO_PARAMS};
|
||||
|
||||
fn half(ctx: &Context<'_>) -> Result<c_double> {
|
||||
@ -752,4 +908,58 @@ mod test {
|
||||
let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
|
||||
assert_eq!(2, result);
|
||||
}
|
||||
|
||||
#[cfg(feature = "window")]
|
||||
impl WindowAggregate<i64, Option<i64>> for Sum {
|
||||
fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
|
||||
*sum -= ctx.get::<i64>(0)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
|
||||
Ok(sum.copied())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "window")]
|
||||
fn test_window() {
|
||||
use fallible_iterator::FallibleIterator;
|
||||
|
||||
let db = Connection::open_in_memory().unwrap();
|
||||
db.create_window_function("sumint", 1, true, Sum).unwrap();
|
||||
db.execute_batch(
|
||||
"CREATE TABLE t3(x, y);
|
||||
INSERT INTO t3 VALUES('a', 4),
|
||||
('b', 5),
|
||||
('c', 3),
|
||||
('d', 8),
|
||||
('e', 1);",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut stmt = db
|
||||
.prepare(
|
||||
"SELECT x, sumint(y) OVER (
|
||||
ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
|
||||
) AS sum_y
|
||||
FROM t3 ORDER BY x;",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let results: Vec<(String, i64)> = stmt
|
||||
.query(NO_PARAMS)
|
||||
.unwrap()
|
||||
.map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
|
||||
.collect()
|
||||
.unwrap();
|
||||
let expected = vec![
|
||||
("a".to_owned(), 9),
|
||||
("b".to_owned(), 12),
|
||||
("c".to_owned(), 16),
|
||||
("d".to_owned(), 12),
|
||||
("e".to_owned(), 9),
|
||||
];
|
||||
assert_eq!(expected, results);
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user