mirror of
https://github.com/isar/rusqlite.git
synced 2024-11-26 19:41:37 +08:00
Merge pull request #539 from gwenn/window-func
Add binding to `sqlite3_create_window_function`
This commit is contained in:
commit
ef036e383c
11
.travis.yml
11
.travis.yml
@ -28,12 +28,9 @@ script:
|
|||||||
- cargo build --features sqlcipher
|
- cargo build --features sqlcipher
|
||||||
- cargo build --features "bundled sqlcipher"
|
- cargo build --features "bundled sqlcipher"
|
||||||
- cargo test
|
- cargo test
|
||||||
- cargo test --features backup
|
- cargo test --features "backup blob"
|
||||||
- cargo test --features blob
|
- cargo test --features "collation functions"
|
||||||
- cargo test --features collation
|
- cargo test --features "hooks limits"
|
||||||
- cargo test --features functions
|
|
||||||
- cargo test --features hooks
|
|
||||||
- cargo test --features limits
|
|
||||||
- cargo test --features load_extension
|
- cargo test --features load_extension
|
||||||
- cargo test --features trace
|
- cargo test --features trace
|
||||||
- cargo test --features chrono
|
- cargo test --features chrono
|
||||||
@ -43,7 +40,7 @@ script:
|
|||||||
- cargo test --features sqlcipher
|
- cargo test --features sqlcipher
|
||||||
- cargo test --features i128_blob
|
- cargo test --features i128_blob
|
||||||
- cargo test --features uuid
|
- cargo test --features uuid
|
||||||
- cargo test --features "unlock_notify bundled"
|
- cargo test --features "bundled unlock_notify window"
|
||||||
- cargo test --features "array bundled csvtab vtab"
|
- cargo test --features "array bundled csvtab vtab"
|
||||||
- cargo test --features "backup blob chrono collation csvtab functions hooks limits load_extension serde_json trace url uuid vtab"
|
- cargo test --features "backup blob chrono collation csvtab functions hooks limits load_extension serde_json trace url uuid vtab"
|
||||||
- cargo test --features "backup blob chrono collation csvtab functions hooks limits load_extension serde_json trace url uuid vtab buildtime_bindgen"
|
- cargo test --features "backup blob chrono collation csvtab functions hooks limits load_extension serde_json trace url uuid vtab buildtime_bindgen"
|
||||||
|
@ -49,6 +49,8 @@ csvtab = ["csv", "vtab"]
|
|||||||
array = ["vtab"]
|
array = ["vtab"]
|
||||||
# session extension: 3.13.0
|
# session extension: 3.13.0
|
||||||
session = ["libsqlite3-sys/session", "hooks"]
|
session = ["libsqlite3-sys/session", "hooks"]
|
||||||
|
# window functions: 3.25.0
|
||||||
|
window = ["functions"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
time = "0.1.0"
|
time = "0.1.0"
|
||||||
|
304
src/functions.rs
304
src/functions.rs
@ -226,6 +226,22 @@ where
|
|||||||
fn finalize(&self, _: Option<A>) -> Result<T>;
|
fn finalize(&self, _: Option<A>) -> Result<T>;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// WindowAggregate is the callback interface for user-defined aggregate window
|
||||||
|
/// function.
|
||||||
|
#[cfg(feature = "window")]
|
||||||
|
pub trait WindowAggregate<A, T>: Aggregate<A, T>
|
||||||
|
where
|
||||||
|
A: RefUnwindSafe + UnwindSafe,
|
||||||
|
T: ToSql,
|
||||||
|
{
|
||||||
|
/// Returns the current value of the aggregate. Unlike xFinal, the
|
||||||
|
/// implementation should not delete any context.
|
||||||
|
fn value(&self, _: Option<&A>) -> Result<T>;
|
||||||
|
|
||||||
|
/// Removes a row from the current window.
|
||||||
|
fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
|
||||||
|
}
|
||||||
|
|
||||||
impl Connection {
|
impl Connection {
|
||||||
/// Attach a user-defined scalar function to this database connection.
|
/// Attach a user-defined scalar function to this database connection.
|
||||||
///
|
///
|
||||||
@ -294,6 +310,24 @@ impl Connection {
|
|||||||
.create_aggregate_function(fn_name, n_arg, deterministic, aggr)
|
.create_aggregate_function(fn_name, n_arg, deterministic, aggr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "window")]
|
||||||
|
pub fn create_window_function<A, W, T>(
|
||||||
|
&self,
|
||||||
|
fn_name: &str,
|
||||||
|
n_arg: c_int,
|
||||||
|
deterministic: bool,
|
||||||
|
aggr: W,
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
A: RefUnwindSafe + UnwindSafe,
|
||||||
|
W: WindowAggregate<A, T>,
|
||||||
|
T: ToSql,
|
||||||
|
{
|
||||||
|
self.db
|
||||||
|
.borrow_mut()
|
||||||
|
.create_window_function(fn_name, n_arg, deterministic, aggr)
|
||||||
|
}
|
||||||
|
|
||||||
/// Removes a user-defined function from this database connection.
|
/// Removes a user-defined function from this database connection.
|
||||||
///
|
///
|
||||||
/// `fn_name` and `n_arg` should match the name and number of arguments
|
/// `fn_name` and `n_arg` should match the name and number of arguments
|
||||||
@ -386,26 +420,100 @@ impl InnerConnection {
|
|||||||
D: Aggregate<A, T>,
|
D: Aggregate<A, T>,
|
||||||
T: ToSql,
|
T: ToSql,
|
||||||
{
|
{
|
||||||
unsafe fn aggregate_context<A>(
|
let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
|
||||||
ctx: *mut sqlite3_context,
|
let c_name = str_to_cstring(fn_name)?;
|
||||||
bytes: usize,
|
let mut flags = ffi::SQLITE_UTF8;
|
||||||
) -> Option<*mut *mut A> {
|
if deterministic {
|
||||||
|
flags |= ffi::SQLITE_DETERMINISTIC;
|
||||||
|
}
|
||||||
|
let r = unsafe {
|
||||||
|
ffi::sqlite3_create_function_v2(
|
||||||
|
self.db(),
|
||||||
|
c_name.as_ptr(),
|
||||||
|
n_arg,
|
||||||
|
flags,
|
||||||
|
boxed_aggr as *mut c_void,
|
||||||
|
None,
|
||||||
|
Some(call_boxed_step::<A, D, T>),
|
||||||
|
Some(call_boxed_final::<A, D, T>),
|
||||||
|
Some(free_boxed_value::<D>),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
self.decode_result(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "window")]
|
||||||
|
fn create_window_function<A, W, T>(
|
||||||
|
&mut self,
|
||||||
|
fn_name: &str,
|
||||||
|
n_arg: c_int,
|
||||||
|
deterministic: bool,
|
||||||
|
aggr: W,
|
||||||
|
) -> Result<()>
|
||||||
|
where
|
||||||
|
A: RefUnwindSafe + UnwindSafe,
|
||||||
|
W: WindowAggregate<A, T>,
|
||||||
|
T: ToSql,
|
||||||
|
{
|
||||||
|
let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
|
||||||
|
let c_name = str_to_cstring(fn_name)?;
|
||||||
|
let mut flags = ffi::SQLITE_UTF8;
|
||||||
|
if deterministic {
|
||||||
|
flags |= ffi::SQLITE_DETERMINISTIC;
|
||||||
|
}
|
||||||
|
let r = unsafe {
|
||||||
|
ffi::sqlite3_create_window_function(
|
||||||
|
self.db(),
|
||||||
|
c_name.as_ptr(),
|
||||||
|
n_arg,
|
||||||
|
flags,
|
||||||
|
boxed_aggr as *mut c_void,
|
||||||
|
Some(call_boxed_step::<A, W, T>),
|
||||||
|
Some(call_boxed_final::<A, W, T>),
|
||||||
|
Some(call_boxed_value::<A, W, T>),
|
||||||
|
Some(call_boxed_inverse::<A, W, T>),
|
||||||
|
Some(free_boxed_value::<W>),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
self.decode_result(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
|
||||||
|
let c_name = str_to_cstring(fn_name)?;
|
||||||
|
let r = unsafe {
|
||||||
|
ffi::sqlite3_create_function_v2(
|
||||||
|
self.db(),
|
||||||
|
c_name.as_ptr(),
|
||||||
|
n_arg,
|
||||||
|
ffi::SQLITE_UTF8,
|
||||||
|
ptr::null_mut(),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
self.decode_result(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
|
||||||
let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
|
let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
|
||||||
if pac.is_null() {
|
if pac.is_null() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
Some(pac)
|
Some(pac)
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "C" fn call_boxed_step<A, D, T>(
|
unsafe extern "C" fn call_boxed_step<A, D, T>(
|
||||||
ctx: *mut sqlite3_context,
|
ctx: *mut sqlite3_context,
|
||||||
argc: c_int,
|
argc: c_int,
|
||||||
argv: *mut *mut sqlite3_value,
|
argv: *mut *mut sqlite3_value,
|
||||||
) where
|
) where
|
||||||
A: RefUnwindSafe + UnwindSafe,
|
A: RefUnwindSafe + UnwindSafe,
|
||||||
D: Aggregate<A, T>,
|
D: Aggregate<A, T>,
|
||||||
T: ToSql,
|
T: ToSql,
|
||||||
{
|
{
|
||||||
let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
|
let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
|
||||||
Some(pac) => pac,
|
Some(pac) => pac,
|
||||||
None => {
|
None => {
|
||||||
@ -440,14 +548,57 @@ impl InnerConnection {
|
|||||||
Ok(_) => {}
|
Ok(_) => {}
|
||||||
Err(err) => report_error(ctx, &err),
|
Err(err) => report_error(ctx, &err),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
|
#[cfg(feature = "window")]
|
||||||
where
|
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
|
||||||
|
ctx: *mut sqlite3_context,
|
||||||
|
argc: c_int,
|
||||||
|
argv: *mut *mut sqlite3_value,
|
||||||
|
) where
|
||||||
|
A: RefUnwindSafe + UnwindSafe,
|
||||||
|
W: WindowAggregate<A, T>,
|
||||||
|
T: ToSql,
|
||||||
|
{
|
||||||
|
let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
|
||||||
|
Some(pac) => pac,
|
||||||
|
None => {
|
||||||
|
ffi::sqlite3_result_error_nomem(ctx);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let r = catch_unwind(|| {
|
||||||
|
let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
|
||||||
|
assert!(
|
||||||
|
!boxed_aggr.is_null(),
|
||||||
|
"Internal error - null aggregate pointer"
|
||||||
|
);
|
||||||
|
let mut ctx = Context {
|
||||||
|
ctx,
|
||||||
|
args: slice::from_raw_parts(argv, argc as usize),
|
||||||
|
};
|
||||||
|
(*boxed_aggr).inverse(&mut ctx, &mut **pac)
|
||||||
|
});
|
||||||
|
let r = match r {
|
||||||
|
Err(_) => {
|
||||||
|
report_error(ctx, &Error::UnwindingPanic);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Ok(r) => r,
|
||||||
|
};
|
||||||
|
match r {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(err) => report_error(ctx, &err),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
|
||||||
|
where
|
||||||
A: RefUnwindSafe + UnwindSafe,
|
A: RefUnwindSafe + UnwindSafe,
|
||||||
D: Aggregate<A, T>,
|
D: Aggregate<A, T>,
|
||||||
T: ToSql,
|
T: ToSql,
|
||||||
{
|
{
|
||||||
// Within the xFinal callback, it is customary to set N=0 in calls to
|
// Within the xFinal callback, it is customary to set N=0 in calls to
|
||||||
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
|
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
|
||||||
let a: Option<A> = match aggregate_context(ctx, 0) {
|
let a: Option<A> = match aggregate_context(ctx, 0) {
|
||||||
@ -483,46 +634,49 @@ impl InnerConnection {
|
|||||||
Ok(Err(err)) => report_error(ctx, &err),
|
Ok(Err(err)) => report_error(ctx, &err),
|
||||||
Err(err) => report_error(ctx, err),
|
Err(err) => report_error(ctx, err),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
|
#[cfg(feature = "window")]
|
||||||
let c_name = str_to_cstring(fn_name)?;
|
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
|
||||||
let mut flags = ffi::SQLITE_UTF8;
|
where
|
||||||
if deterministic {
|
A: RefUnwindSafe + UnwindSafe,
|
||||||
flags |= ffi::SQLITE_DETERMINISTIC;
|
W: WindowAggregate<A, T>,
|
||||||
|
T: ToSql,
|
||||||
|
{
|
||||||
|
// Within the xValue callback, it is customary to set N=0 in calls to
|
||||||
|
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
|
||||||
|
let a: Option<&A> = match aggregate_context(ctx, 0) {
|
||||||
|
Some(pac) => {
|
||||||
|
if (*pac as *mut A).is_null() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let a = &**pac;
|
||||||
|
Some(a)
|
||||||
}
|
}
|
||||||
let r = unsafe {
|
}
|
||||||
ffi::sqlite3_create_function_v2(
|
None => None,
|
||||||
self.db(),
|
|
||||||
c_name.as_ptr(),
|
|
||||||
n_arg,
|
|
||||||
flags,
|
|
||||||
boxed_aggr as *mut c_void,
|
|
||||||
None,
|
|
||||||
Some(call_boxed_step::<A, D, T>),
|
|
||||||
Some(call_boxed_final::<A, D, T>),
|
|
||||||
Some(free_boxed_value::<D>),
|
|
||||||
)
|
|
||||||
};
|
};
|
||||||
self.decode_result(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
|
let r = catch_unwind(|| {
|
||||||
let c_name = str_to_cstring(fn_name)?;
|
let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
|
||||||
let r = unsafe {
|
assert!(
|
||||||
ffi::sqlite3_create_function_v2(
|
!boxed_aggr.is_null(),
|
||||||
self.db(),
|
"Internal error - null aggregate pointer"
|
||||||
c_name.as_ptr(),
|
);
|
||||||
n_arg,
|
(*boxed_aggr).value(a)
|
||||||
ffi::SQLITE_UTF8,
|
});
|
||||||
ptr::null_mut(),
|
let t = match r {
|
||||||
None,
|
Err(_) => {
|
||||||
None,
|
report_error(ctx, &Error::UnwindingPanic);
|
||||||
None,
|
return;
|
||||||
None,
|
}
|
||||||
)
|
Ok(r) => r,
|
||||||
};
|
};
|
||||||
self.decode_result(r)
|
let t = t.as_ref().map(|t| ToSql::to_sql(t));
|
||||||
|
match t {
|
||||||
|
Ok(Ok(ref value)) => set_result(ctx, value),
|
||||||
|
Ok(Err(err)) => report_error(ctx, &err),
|
||||||
|
Err(err) => report_error(ctx, err),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -535,6 +689,8 @@ mod test {
|
|||||||
use std::os::raw::c_double;
|
use std::os::raw::c_double;
|
||||||
|
|
||||||
use crate::functions::{Aggregate, Context};
|
use crate::functions::{Aggregate, Context};
|
||||||
|
#[cfg(feature = "window")]
|
||||||
|
use crate::functions::WindowAggregate;
|
||||||
use crate::{Connection, Error, Result, NO_PARAMS};
|
use crate::{Connection, Error, Result, NO_PARAMS};
|
||||||
|
|
||||||
fn half(ctx: &Context<'_>) -> Result<c_double> {
|
fn half(ctx: &Context<'_>) -> Result<c_double> {
|
||||||
@ -752,4 +908,58 @@ mod test {
|
|||||||
let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
|
let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
|
||||||
assert_eq!(2, result);
|
assert_eq!(2, result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "window")]
|
||||||
|
impl WindowAggregate<i64, Option<i64>> for Sum {
|
||||||
|
fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
|
||||||
|
*sum -= ctx.get::<i64>(0)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
|
||||||
|
Ok(sum.copied())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[cfg(feature = "window")]
|
||||||
|
fn test_window() {
|
||||||
|
use fallible_iterator::FallibleIterator;
|
||||||
|
|
||||||
|
let db = Connection::open_in_memory().unwrap();
|
||||||
|
db.create_window_function("sumint", 1, true, Sum).unwrap();
|
||||||
|
db.execute_batch(
|
||||||
|
"CREATE TABLE t3(x, y);
|
||||||
|
INSERT INTO t3 VALUES('a', 4),
|
||||||
|
('b', 5),
|
||||||
|
('c', 3),
|
||||||
|
('d', 8),
|
||||||
|
('e', 1);",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut stmt = db
|
||||||
|
.prepare(
|
||||||
|
"SELECT x, sumint(y) OVER (
|
||||||
|
ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
|
||||||
|
) AS sum_y
|
||||||
|
FROM t3 ORDER BY x;",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let results: Vec<(String, i64)> = stmt
|
||||||
|
.query(NO_PARAMS)
|
||||||
|
.unwrap()
|
||||||
|
.map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
|
||||||
|
.collect()
|
||||||
|
.unwrap();
|
||||||
|
let expected = vec![
|
||||||
|
("a".to_owned(), 9),
|
||||||
|
("b".to_owned(), 12),
|
||||||
|
("c".to_owned(), 16),
|
||||||
|
("d".to_owned(), 12),
|
||||||
|
("e".to_owned(), 9),
|
||||||
|
];
|
||||||
|
assert_eq!(expected, results);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user