Introduce alloc to generate C string allocated by sqlite3

Safe to send to SQLite for deallocation.
This commit is contained in:
gwenn 2020-02-29 13:02:37 +01:00
parent a3e5ea990d
commit 5356a609de
2 changed files with 31 additions and 25 deletions

View File

@ -372,25 +372,19 @@ rusqlite was built against SQLite {} but the runtime SQLite version is {}. To fi
}); });
} }
#[cfg(not(any( #[cfg(not(any(target_arch = "wasm32")))]
target_arch = "wasm32"
)))]
static SQLITE_INIT: std::sync::Once = std::sync::Once::new(); static SQLITE_INIT: std::sync::Once = std::sync::Once::new();
pub static BYPASS_SQLITE_INIT: AtomicBool = AtomicBool::new(false); pub static BYPASS_SQLITE_INIT: AtomicBool = AtomicBool::new(false);
// threading mode checks are not necessary (and do not work) on target // threading mode checks are not necessary (and do not work) on target
// platforms that do not have threading (such as webassembly) // platforms that do not have threading (such as webassembly)
#[cfg(any( #[cfg(any(target_arch = "wasm32"))]
target_arch = "wasm32"
))]
fn ensure_safe_sqlite_threading_mode() -> Result<()> { fn ensure_safe_sqlite_threading_mode() -> Result<()> {
Ok(()) Ok(())
} }
#[cfg(not(any( #[cfg(not(any(target_arch = "wasm32")))]
target_arch = "wasm32"
)))]
fn ensure_safe_sqlite_threading_mode() -> Result<()> { fn ensure_safe_sqlite_threading_mode() -> Result<()> {
// Ensure SQLite was compiled in thredsafe mode. // Ensure SQLite was compiled in thredsafe mode.
if unsafe { ffi::sqlite3_threadsafe() == 0 } { if unsafe { ffi::sqlite3_threadsafe() == 0 } {

View File

@ -10,7 +10,6 @@
//! //!
//! (See [SQLite doc](http://sqlite.org/vtab.html)) //! (See [SQLite doc](http://sqlite.org/vtab.html))
use std::borrow::Cow::{self, Borrowed, Owned}; use std::borrow::Cow::{self, Borrowed, Owned};
use std::ffi::CString;
use std::marker::PhantomData; use std::marker::PhantomData;
use std::marker::Sync; use std::marker::Sync;
use std::os::raw::{c_char, c_int, c_void}; use std::os::raw::{c_char, c_int, c_void};
@ -705,23 +704,23 @@ where
ffi::SQLITE_OK ffi::SQLITE_OK
} else { } else {
let err = error_from_sqlite_code(rc, None); let err = error_from_sqlite_code(rc, None);
*err_msg = mprintf(&err.to_string()); *err_msg = alloc(&err.to_string());
rc rc
} }
} }
Err(err) => { Err(err) => {
*err_msg = mprintf(&err.to_string()); *err_msg = alloc(&err.to_string());
ffi::SQLITE_ERROR ffi::SQLITE_ERROR
} }
}, },
Err(Error::SqliteFailure(err, s)) => { Err(Error::SqliteFailure(err, s)) => {
if let Some(s) = s { if let Some(s) = s {
*err_msg = mprintf(&s); *err_msg = alloc(&s);
} }
err.extended_code err.extended_code
} }
Err(err) => { Err(err) => {
*err_msg = mprintf(&err.to_string()); *err_msg = alloc(&err.to_string());
ffi::SQLITE_ERROR ffi::SQLITE_ERROR
} }
} }
@ -757,23 +756,23 @@ where
ffi::SQLITE_OK ffi::SQLITE_OK
} else { } else {
let err = error_from_sqlite_code(rc, None); let err = error_from_sqlite_code(rc, None);
*err_msg = mprintf(&err.to_string()); *err_msg = alloc(&err.to_string());
rc rc
} }
} }
Err(err) => { Err(err) => {
*err_msg = mprintf(&err.to_string()); *err_msg = alloc(&err.to_string());
ffi::SQLITE_ERROR ffi::SQLITE_ERROR
} }
}, },
Err(Error::SqliteFailure(err, s)) => { Err(Error::SqliteFailure(err, s)) => {
if let Some(s) = s { if let Some(s) = s {
*err_msg = mprintf(&s); *err_msg = alloc(&s);
} }
err.extended_code err.extended_code
} }
Err(err) => { Err(err) => {
*err_msg = mprintf(&err.to_string()); *err_msg = alloc(&err.to_string());
ffi::SQLITE_ERROR ffi::SQLITE_ERROR
} }
} }
@ -971,7 +970,7 @@ unsafe fn set_err_msg(vtab: *mut ffi::sqlite3_vtab, err_msg: &str) {
if !(*vtab).zErrMsg.is_null() { if !(*vtab).zErrMsg.is_null() {
ffi::sqlite3_free((*vtab).zErrMsg as *mut c_void); ffi::sqlite3_free((*vtab).zErrMsg as *mut c_void);
} }
(*vtab).zErrMsg = mprintf(err_msg); (*vtab).zErrMsg = alloc(err_msg);
} }
/// To raise an error, the `column` method should use this method to set the /// To raise an error, the `column` method should use this method to set the
@ -1006,12 +1005,25 @@ unsafe fn result_error<T>(ctx: *mut ffi::sqlite3_context, result: Result<T>) ->
} }
} }
// Space to hold this error message string must be obtained // Space to hold this string must be obtained
// from an SQLite memory allocation function. // from an SQLite memory allocation function
fn mprintf(err_msg: &str) -> *mut c_char { unsafe fn alloc<S: AsRef<[u8]>>(s: S) -> *mut c_char {
let c_format = CString::new("%s").unwrap(); use std::convert::TryInto;
let c_err = CString::new(err_msg).unwrap(); let s = s.as_ref();
unsafe { ffi::sqlite3_mprintf(c_format.as_ptr(), c_err.as_ptr()) } if memchr::memchr(0, s).is_some() {
panic!("Null character found")
}
let len = s.len();
let total_len = len.checked_add(1).unwrap();
let dst = ffi::sqlite3_malloc(total_len.try_into().unwrap()) as *mut c_char;
if dst.is_null() {
panic!("Out of memory")
}
ptr::copy_nonoverlapping(s.as_ptr() as *const c_char, dst, len);
// null terminator
*dst.offset(len.try_into().unwrap()) = 0;
dst
} }
#[cfg(feature = "array")] #[cfg(feature = "array")]