rusqlite/src/collation.rs

216 lines
6.3 KiB
Rust
Raw Normal View History

2021-06-13 15:17:35 +08:00
//! Add, remove, or modify a collation
use std::cmp::Ordering;
use std::os::raw::{c_char, c_int, c_void};
use std::panic::{catch_unwind, UnwindSafe};
use std::ptr;
use std::slice;
use crate::ffi;
use crate::{str_to_cstring, Connection, InnerConnection, Result};
// FIXME copy/paste from function.rs
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
2022-01-06 02:50:25 +08:00
drop(Box::from_raw(p.cast::<T>()));
}
impl Connection {
2021-06-13 15:17:35 +08:00
/// Add or modify a collation.
#[inline]
2021-12-01 03:17:29 +08:00
pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()>
where
2021-12-01 03:17:29 +08:00
C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
{
self.db
.borrow_mut()
.create_collation(collation_name, x_compare)
}
2021-06-13 15:17:35 +08:00
/// Collation needed callback
#[inline]
pub fn collation_needed(
&self,
x_coll_needed: fn(&Connection, &str) -> Result<()>,
) -> Result<()> {
self.db.borrow_mut().collation_needed(x_coll_needed)
}
2021-06-13 15:17:35 +08:00
/// Remove collation.
#[inline]
pub fn remove_collation(&self, collation_name: &str) -> Result<()> {
self.db.borrow_mut().remove_collation(collation_name)
}
}
impl InnerConnection {
2021-12-01 03:17:29 +08:00
fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()>
where
2021-12-01 03:17:29 +08:00
C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
{
unsafe extern "C" fn call_boxed_closure<C>(
arg1: *mut c_void,
arg2: c_int,
arg3: *const c_void,
arg4: c_int,
arg5: *const c_void,
) -> c_int
where
C: Fn(&str, &str) -> Ordering,
{
let r = catch_unwind(|| {
2022-01-06 02:50:25 +08:00
let boxed_f: *mut C = arg1.cast::<C>();
assert!(!boxed_f.is_null(), "Internal error - null function pointer");
let s1 = {
2022-01-06 02:50:25 +08:00
let c_slice = slice::from_raw_parts(arg3.cast::<u8>(), arg2 as usize);
String::from_utf8_lossy(c_slice)
};
let s2 = {
2022-01-06 02:50:25 +08:00
let c_slice = slice::from_raw_parts(arg5.cast::<u8>(), arg4 as usize);
String::from_utf8_lossy(c_slice)
};
(*boxed_f)(s1.as_ref(), s2.as_ref())
});
let t = match r {
Err(_) => {
return -1; // FIXME How ?
}
Ok(r) => r,
};
match t {
Ordering::Less => -1,
Ordering::Equal => 0,
Ordering::Greater => 1,
}
}
let boxed_f: *mut C = Box::into_raw(Box::new(x_compare));
let c_name = str_to_cstring(collation_name)?;
let flags = ffi::SQLITE_UTF8;
let r = unsafe {
ffi::sqlite3_create_collation_v2(
self.db(),
c_name.as_ptr(),
flags,
2022-01-06 02:50:25 +08:00
boxed_f.cast::<c_void>(),
Some(call_boxed_closure::<C>),
Some(free_boxed_value::<C>),
)
};
let res = self.decode_result(r);
// The xDestroy callback is not called if the sqlite3_create_collation_v2()
// function fails.
if res.is_err() {
drop(unsafe { Box::from_raw(boxed_f) });
}
res
}
fn collation_needed(
&mut self,
x_coll_needed: fn(&Connection, &str) -> Result<()>,
) -> Result<()> {
use std::mem;
#[allow(clippy::needless_return)]
unsafe extern "C" fn collation_needed_callback(
arg1: *mut c_void,
arg2: *mut ffi::sqlite3,
e_text_rep: c_int,
arg3: *const c_char,
) {
use std::ffi::CStr;
use std::str;
if e_text_rep != ffi::SQLITE_UTF8 {
// TODO: validate
return;
}
let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1);
let res = catch_unwind(|| {
let conn = Connection::from_handle(arg2).unwrap();
let collation_name = {
let c_slice = CStr::from_ptr(arg3).to_bytes();
str::from_utf8(c_slice).expect("illegal collation sequence name")
};
callback(&conn, collation_name)
});
if res.is_err() {
return; // FIXME How ?
}
}
let r = unsafe {
ffi::sqlite3_collation_needed(
self.db(),
2020-08-18 02:07:56 +08:00
x_coll_needed as *mut c_void,
Some(collation_needed_callback),
)
};
self.decode_result(r)
}
#[inline]
fn remove_collation(&mut self, collation_name: &str) -> Result<()> {
let c_name = str_to_cstring(collation_name)?;
let r = unsafe {
ffi::sqlite3_create_collation_v2(
self.db(),
c_name.as_ptr(),
ffi::SQLITE_UTF8,
ptr::null_mut(),
None,
None,
)
};
self.decode_result(r)
}
}
#[cfg(test)]
mod test {
use crate::{Connection, Result};
use fallible_streaming_iterator::FallibleStreamingIterator;
use std::cmp::Ordering;
use unicase::UniCase;
fn unicase_compare(s1: &str, s2: &str) -> Ordering {
UniCase::new(s1).cmp(&UniCase::new(s2))
}
#[test]
2020-11-06 05:14:00 +08:00
fn test_unicase() -> Result<()> {
let db = Connection::open_in_memory()?;
2020-11-06 05:14:00 +08:00
db.create_collation("unicase", unicase_compare)?;
2020-11-06 05:14:00 +08:00
collate(db)
}
2020-11-06 05:14:00 +08:00
fn collate(db: Connection) -> Result<()> {
db.execute_batch(
"CREATE TABLE foo (bar);
INSERT INTO foo (bar) VALUES ('Maße');
INSERT INTO foo (bar) VALUES ('MASSE');",
2020-11-06 05:14:00 +08:00
)?;
let mut stmt = db.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")?;
let rows = stmt.query([])?;
assert_eq!(rows.count()?, 1);
Ok(())
}
fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> {
if "unicase" == collation_name {
db.create_collation(collation_name, unicase_compare)
} else {
Ok(())
}
}
#[test]
2020-11-06 05:14:00 +08:00
fn test_collation_needed() -> Result<()> {
let db = Connection::open_in_memory()?;
db.collation_needed(collation_needed)?;
collate(db)
}
}