Add binding to sqlite3_collation_needed

This commit is contained in:
gwenn 2019-06-18 19:03:13 +02:00
parent 4988715932
commit f6199df9f9

View File

@ -1,6 +1,6 @@
//! Add, remove, or modify a collation
use std::cmp::Ordering;
use std::os::raw::{c_int, c_void};
use std::os::raw::{c_char, c_int, c_void};
use std::panic::{catch_unwind, UnwindSafe};
use std::ptr;
use std::slice;
@ -8,8 +8,6 @@ use std::slice;
use crate::ffi;
use crate::{str_to_cstring, Connection, InnerConnection, Result};
// TODO sqlite3_collation_needed https://sqlite.org/c3ref/collation_needed.html
// FIXME copy/paste from function.rs
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
drop(Box::from_raw(p as *mut T));
@ -26,6 +24,14 @@ impl Connection {
.create_collation(collation_name, x_compare)
}
/// Collation needed callback
pub fn collation_needed(
&self,
x_coll_needed: fn(&Connection, &str) -> Result<()>,
) -> Result<()> {
self.db.borrow_mut().collation_needed(x_coll_needed)
}
/// Remove collation.
pub fn remove_collation(&self, collation_name: &str) -> Result<()> {
self.db.borrow_mut().remove_collation(collation_name)
@ -37,7 +43,7 @@ impl InnerConnection {
where
C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
{
unsafe extern "C" fn call_boxed_closure<F>(
unsafe extern "C" fn call_boxed_closure<C>(
arg1: *mut c_void,
arg2: c_int,
arg3: *const c_void,
@ -45,12 +51,12 @@ impl InnerConnection {
arg5: *const c_void,
) -> c_int
where
F: Fn(&str, &str) -> Ordering,
C: Fn(&str, &str) -> Ordering,
{
use std::str;
let r = catch_unwind(|| {
let boxed_f: *mut F = arg1 as *mut F;
let boxed_f: *mut C = arg1 as *mut C;
assert!(!boxed_f.is_null(), "Internal error - null function pointer");
let s1 = {
let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize);
@ -92,6 +98,48 @@ impl InnerConnection {
self.decode_result(r)
}
fn collation_needed(
&mut self,
x_coll_needed: fn(&Connection, &str) -> Result<()>,
) -> Result<()> {
use std::mem;
unsafe extern "C" fn collation_needed_callback(
arg1: *mut c_void,
arg2: *mut ffi::sqlite3,
e_text_rep: c_int,
arg3: *const c_char,
) {
use std::ffi::CStr;
use std::str;
if e_text_rep != ffi::SQLITE_UTF8 {
// TODO: validate
return;
}
let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1);
if let Err(_) = catch_unwind(|| {
let conn = Connection::from_handle(arg2).unwrap();
let collation_name = {
let c_slice = CStr::from_ptr(arg3).to_bytes();
str::from_utf8_unchecked(c_slice)
};
callback(&conn, collation_name)
}) {
return; // FIXME How ?
}
}
let r = unsafe {
ffi::sqlite3_collation_needed(
self.db(),
mem::transmute(x_coll_needed),
Some(collation_needed_callback),
)
};
self.decode_result(r)
}
fn remove_collation(&mut self, collation_name: &str) -> Result<()> {
let c_name = str_to_cstring(collation_name)?;
let r = unsafe {
@ -110,7 +158,7 @@ impl InnerConnection {
#[cfg(test)]
mod test {
use crate::{Connection, NO_PARAMS};
use crate::{Connection, Result, NO_PARAMS};
use fallible_streaming_iterator::FallibleStreamingIterator;
use std::cmp::Ordering;
use unicase::UniCase;
@ -125,6 +173,10 @@ mod test {
db.create_collation("unicase", unicase_compare).unwrap();
collate(db);
}
fn collate(db: Connection) {
db.execute_batch(
"CREATE TABLE foo (bar);
INSERT INTO foo (bar) VALUES ('Maße');
@ -137,4 +189,19 @@ mod test {
let rows = stmt.query(NO_PARAMS).unwrap();
assert_eq!(rows.count().unwrap(), 1);
}
fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> {
if "unicase" == collation_name {
db.create_collation(collation_name, unicase_compare)
} else {
Ok(())
}
}
#[test]
fn test_collation_needed() {
let db = Connection::open_in_memory().unwrap();
db.collation_needed(collation_needed).unwrap();
collate(db);
}
}