diff --git a/Cargo.toml b/Cargo.toml index 5cce914..f13a785 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -99,6 +99,7 @@ fallible-iterator = "0.2" fallible-streaming-iterator = "0.1" memchr = "2.2.0" uuid = { version = "0.8", optional = true } +smallvec = "1.3" [dev-dependencies] doc-comment = "0.3" diff --git a/src/lib.rs b/src/lib.rs index c4860d2..7dc7b03 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -128,6 +128,8 @@ mod version; #[cfg(feature = "vtab")] pub mod vtab; +pub(crate) mod util; + // Number of cached prepared statements we'll hold on to. const STATEMENT_CACHE_DEFAULT_CAPACITY: usize = 16; /// To be used when your statement has no [parameter](https://sqlite.org/lang_expr.html#varparam). @@ -274,7 +276,7 @@ fn path_to_cstring(p: &Path) -> Result { #[cfg(not(unix))] fn path_to_cstring(p: &Path) -> Result { let s = p.to_str().ok_or_else(|| Error::InvalidPath(p.to_owned()))?; - str_to_cstring(s) + Ok(CString::new(s)?) } /// Name for a database within a SQLite connection. diff --git a/src/raw_statement.rs b/src/raw_statement.rs index 05634d6..610dcd1 100644 --- a/src/raw_statement.rs +++ b/src/raw_statement.rs @@ -7,11 +7,11 @@ use std::ptr; // Private newtype for raw sqlite3_stmts that finalize themselves when dropped. #[derive(Debug)] -pub struct RawStatement(*mut ffi::sqlite3_stmt, bool); +pub struct RawStatement(*mut ffi::sqlite3_stmt, bool, crate::util::ParamIndexCache); impl RawStatement { pub unsafe fn new(stmt: *mut ffi::sqlite3_stmt, tail: bool) -> RawStatement { - RawStatement(stmt, tail) + RawStatement(stmt, tail, Default::default()) } pub fn is_null(&self) -> bool { @@ -87,12 +87,14 @@ impl RawStatement { unsafe { ffi::sqlite3_bind_parameter_count(self.0) as usize } } - pub fn bind_parameter_index(&self, name: &CStr) -> Option { - let r = unsafe { ffi::sqlite3_bind_parameter_index(self.0, name.as_ptr()) }; - match r { - 0 => None, - i => Some(i as usize), - } + pub fn bind_parameter_index(&self, name: &str) -> Option { + self.2.get_or_insert_with(name, |param_cstr| { + let r = unsafe { ffi::sqlite3_bind_parameter_index(self.0, param_cstr.as_ptr()) }; + match r { + 0 => None, + i => Some(i as usize), + } + }) } pub fn clear_bindings(&self) -> c_int { diff --git a/src/statement.rs b/src/statement.rs index 235b1b7..b4edfbc 100644 --- a/src/statement.rs +++ b/src/statement.rs @@ -6,7 +6,7 @@ use std::slice::from_raw_parts; use std::{convert, fmt, mem, ptr, str}; use super::ffi; -use super::{len_as_c_int, str_for_sqlite, str_to_cstring}; +use super::{len_as_c_int, str_for_sqlite}; use super::{ AndThenRows, Connection, Error, MappedRows, RawStatement, Result, Row, Rows, ValueRef, }; @@ -432,8 +432,7 @@ impl Statement<'_> { /// Will return Err if `name` is invalid. Will return Ok(None) if the name /// is valid but not a bound parameter of this statement. pub fn parameter_index(&self, name: &str) -> Result> { - let c_name = str_to_cstring(name)?; - Ok(self.stmt.bind_parameter_index(&c_name)) + Ok(self.stmt.bind_parameter_index(name)) } fn bind_parameters

(&mut self, params: P) -> Result<()> diff --git a/src/util/mod.rs b/src/util/mod.rs new file mode 100644 index 0000000..d835fcf --- /dev/null +++ b/src/util/mod.rs @@ -0,0 +1,5 @@ +// Internal utilities +pub(crate) mod param_cache; +mod small_cstr; +pub(crate) use param_cache::ParamIndexCache; +pub(crate) use small_cstr::SmallCString; diff --git a/src/util/param_cache.rs b/src/util/param_cache.rs new file mode 100644 index 0000000..6faced9 --- /dev/null +++ b/src/util/param_cache.rs @@ -0,0 +1,60 @@ +use super::SmallCString; +use std::cell::RefCell; +use std::collections::BTreeMap; + +/// Maps parameter names to parameter indices. +#[derive(Default, Clone, Debug)] +// BTreeMap seems to do better here unless we want to pull in a custom hash +// function. +pub(crate) struct ParamIndexCache(RefCell>); + +impl ParamIndexCache { + pub fn get_or_insert_with(&self, s: &str, func: F) -> Option + where + F: FnOnce(&std::ffi::CStr) -> Option, + { + let mut cache = self.0.borrow_mut(); + // Avoid entry API, needs allocation to test membership. + if let Some(v) = cache.get(s) { + return Some(*v); + } + // If there's an internal nul in the name it couldn't have been a + // parameter, so early return here is ok. + let name = SmallCString::new(s).ok()?; + let val = func(&name)?; + cache.insert(name, val); + Some(val) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn test_cache() { + let p = ParamIndexCache::default(); + let v = p.get_or_insert_with("foo", |cstr| { + assert_eq!(cstr.to_str().unwrap(), "foo"); + Some(3) + }); + assert_eq!(v, Some(3)); + let v = p.get_or_insert_with("foo", |_| { + panic!("shouldn't be called this time"); + }); + assert_eq!(v, Some(3)); + let v = p.get_or_insert_with("gar\0bage", |_| { + panic!("shouldn't be called here either"); + }); + assert_eq!(v, None); + let v = p.get_or_insert_with("bar", |cstr| { + assert_eq!(cstr.to_str().unwrap(), "bar"); + None + }); + assert_eq!(v, None); + let v = p.get_or_insert_with("bar", |cstr| { + assert_eq!(cstr.to_str().unwrap(), "bar"); + Some(30) + }); + assert_eq!(v, Some(30)); + } +} diff --git a/src/util/small_cstr.rs b/src/util/small_cstr.rs new file mode 100644 index 0000000..fae52f0 --- /dev/null +++ b/src/util/small_cstr.rs @@ -0,0 +1,132 @@ +use smallvec::{smallvec, SmallVec}; +use std::ffi::{CStr, CString, NulError}; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct SmallCString(smallvec::SmallVec<[u8; 16]>); + +impl SmallCString { + #[inline] + pub fn new(s: &str) -> Result { + if s.as_bytes().contains(&0u8) { + return Err(Self::fabricate_nul_error(s)); + } + let mut buf = SmallVec::with_capacity(s.len() + 1); + buf.extend_from_slice(s.as_bytes()); + buf.push(0); + Ok(Self(buf)) + } + + #[inline] + pub fn as_str(&self) -> &str { + debug_assert!(std::str::from_utf8(&self.as_bytes_without_nul()).is_ok()); + // Constructor takes a &str so this is safe. + unsafe { std::str::from_utf8_unchecked(&self.as_bytes_without_nul()) } + } + + #[inline] + pub fn as_bytes_without_nul(&self) -> &[u8] { + &self.0[..self.0.len() - 1] + } + + #[inline] + pub fn len(&self) -> usize { + debug_assert_ne!(self.0.len(), 0); + self.0.len() - 1 + } + + #[inline] + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + #[inline] + pub fn as_cstr(&self) -> &CStr { + debug_assert!(CStr::from_bytes_with_nul(&self.0).is_ok()); + unsafe { CStr::from_bytes_with_nul_unchecked(&self.0) } + } + + #[cold] + fn fabricate_nul_error(b: &str) -> NulError { + CString::new(b).unwrap_err() + } +} + +impl Default for SmallCString { + #[inline] + fn default() -> Self { + Self(smallvec![0]) + } +} + +impl std::fmt::Debug for SmallCString { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_tuple("SmallCString").field(&self.as_str()).finish() + } +} + +impl std::ops::Deref for SmallCString { + type Target = CStr; + #[inline] + fn deref(&self) -> &CStr { + self.as_cstr() + } +} + +impl PartialEq for str { + #[inline] + fn eq(&self, s: &SmallCString) -> bool { + s.as_bytes_without_nul() == self.as_bytes() + } +} + +impl PartialEq for SmallCString { + #[inline] + fn eq(&self, s: &str) -> bool { + self.as_bytes_without_nul() == s.as_bytes() + } +} + +impl std::borrow::Borrow for SmallCString { + #[inline] + fn borrow(&self) -> &str { + self.as_str() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_small_cstring() { + // We don't go through the normal machinery for default, so make sure + // things work. + assert_eq!(SmallCString::default().0, SmallCString::new("").unwrap().0); + assert_eq!(SmallCString::new("foo").unwrap().len(), 3); + assert_eq!(SmallCString::new("foo").unwrap().0.as_slice(), b"foo\0"); + assert_eq!( + SmallCString::new("foo").unwrap().as_bytes_without_nul(), + b"foo", + ); + + assert_eq!(SmallCString::new("😀").unwrap().len(), 4); + assert_eq!( + SmallCString::new("😀").unwrap().0.as_slice(), + b"\xf0\x9f\x98\x80\0", + ); + assert_eq!( + SmallCString::new("😀").unwrap().as_bytes_without_nul(), + b"\xf0\x9f\x98\x80", + ); + + assert_eq!(SmallCString::new("").unwrap().len(), 0); + assert!(SmallCString::new("").unwrap().is_empty()); + + assert_eq!(SmallCString::new("").unwrap().0.as_slice(), b"\0"); + assert_eq!(SmallCString::new("").unwrap().as_bytes_without_nul(), b""); + + assert!(SmallCString::new("\0").is_err()); + assert!(SmallCString::new("\0abc").is_err()); + assert!(SmallCString::new("abc\0").is_err()); + } +}