From f6199df9f923ff4baee1beac4971e7afa1bf6cec Mon Sep 17 00:00:00 2001 From: gwenn Date: Tue, 18 Jun 2019 19:03:13 +0200 Subject: [PATCH] Add binding to `sqlite3_collation_needed` --- src/collation.rs | 81 +++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 74 insertions(+), 7 deletions(-) diff --git a/src/collation.rs b/src/collation.rs index 03e9837..43ea660 100644 --- a/src/collation.rs +++ b/src/collation.rs @@ -1,6 +1,6 @@ //! Add, remove, or modify a collation use std::cmp::Ordering; -use std::os::raw::{c_int, c_void}; +use std::os::raw::{c_char, c_int, c_void}; use std::panic::{catch_unwind, UnwindSafe}; use std::ptr; use std::slice; @@ -8,8 +8,6 @@ use std::slice; use crate::ffi; use crate::{str_to_cstring, Connection, InnerConnection, Result}; -// TODO sqlite3_collation_needed https://sqlite.org/c3ref/collation_needed.html - // FIXME copy/paste from function.rs unsafe extern "C" fn free_boxed_value(p: *mut c_void) { drop(Box::from_raw(p as *mut T)); @@ -26,6 +24,14 @@ impl Connection { .create_collation(collation_name, x_compare) } + /// Collation needed callback + pub fn collation_needed( + &self, + x_coll_needed: fn(&Connection, &str) -> Result<()>, + ) -> Result<()> { + self.db.borrow_mut().collation_needed(x_coll_needed) + } + /// Remove collation. pub fn remove_collation(&self, collation_name: &str) -> Result<()> { self.db.borrow_mut().remove_collation(collation_name) @@ -37,7 +43,7 @@ impl InnerConnection { where C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static, { - unsafe extern "C" fn call_boxed_closure( + unsafe extern "C" fn call_boxed_closure( arg1: *mut c_void, arg2: c_int, arg3: *const c_void, @@ -45,12 +51,12 @@ impl InnerConnection { arg5: *const c_void, ) -> c_int where - F: Fn(&str, &str) -> Ordering, + C: Fn(&str, &str) -> Ordering, { use std::str; let r = catch_unwind(|| { - let boxed_f: *mut F = arg1 as *mut F; + let boxed_f: *mut C = arg1 as *mut C; assert!(!boxed_f.is_null(), "Internal error - null function pointer"); let s1 = { let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize); @@ -92,6 +98,48 @@ impl InnerConnection { self.decode_result(r) } + fn collation_needed( + &mut self, + x_coll_needed: fn(&Connection, &str) -> Result<()>, + ) -> Result<()> { + use std::mem; + unsafe extern "C" fn collation_needed_callback( + arg1: *mut c_void, + arg2: *mut ffi::sqlite3, + e_text_rep: c_int, + arg3: *const c_char, + ) { + use std::ffi::CStr; + use std::str; + + if e_text_rep != ffi::SQLITE_UTF8 { + // TODO: validate + return; + } + + let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1); + if let Err(_) = catch_unwind(|| { + let conn = Connection::from_handle(arg2).unwrap(); + let collation_name = { + let c_slice = CStr::from_ptr(arg3).to_bytes(); + str::from_utf8_unchecked(c_slice) + }; + callback(&conn, collation_name) + }) { + return; // FIXME How ? + } + } + + let r = unsafe { + ffi::sqlite3_collation_needed( + self.db(), + mem::transmute(x_coll_needed), + Some(collation_needed_callback), + ) + }; + self.decode_result(r) + } + fn remove_collation(&mut self, collation_name: &str) -> Result<()> { let c_name = str_to_cstring(collation_name)?; let r = unsafe { @@ -110,7 +158,7 @@ impl InnerConnection { #[cfg(test)] mod test { - use crate::{Connection, NO_PARAMS}; + use crate::{Connection, Result, NO_PARAMS}; use fallible_streaming_iterator::FallibleStreamingIterator; use std::cmp::Ordering; use unicase::UniCase; @@ -125,6 +173,10 @@ mod test { db.create_collation("unicase", unicase_compare).unwrap(); + collate(db); + } + + fn collate(db: Connection) { db.execute_batch( "CREATE TABLE foo (bar); INSERT INTO foo (bar) VALUES ('Maße'); @@ -137,4 +189,19 @@ mod test { let rows = stmt.query(NO_PARAMS).unwrap(); assert_eq!(rows.count().unwrap(), 1); } + + fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> { + if "unicase" == collation_name { + db.create_collation(collation_name, unicase_compare) + } else { + Ok(()) + } + } + + #[test] + fn test_collation_needed() { + let db = Connection::open_in_memory().unwrap(); + db.collation_needed(collation_needed).unwrap(); + collate(db); + } }