Merge remote-tracking branch 'upstream/master' into pr/open-with-vfs

# Conflicts:
#	src/inner_connection.rs
This commit is contained in:
gwenn
2020-02-09 11:58:49 +01:00
43 changed files with 14816 additions and 7924 deletions

View File

@@ -1,4 +1,5 @@
///! Busy handler (when the database is locked)
use std::convert::TryInto;
use std::mem;
use std::os::raw::{c_int, c_void};
use std::panic::catch_unwind;
@@ -21,12 +22,13 @@ impl Connection {
/// (using `busy_handler`) prior to calling this routine, that other
/// busy handler is cleared.
pub fn busy_timeout(&self, timeout: Duration) -> Result<()> {
let ms = timeout
let ms: i32 = timeout
.as_secs()
.checked_mul(1000)
.and_then(|t| t.checked_add(timeout.subsec_millis().into()))
.and_then(|t| t.try_into().ok())
.expect("too big");
self.db.borrow_mut().busy_timeout(ms as i32)
self.db.borrow_mut().busy_timeout(ms)
}
/// Register a callback to handle `SQLITE_BUSY` errors.
@@ -75,18 +77,17 @@ impl InnerConnection {
#[cfg(test)]
mod test {
use self::tempdir::TempDir;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::sync_channel;
use std::thread;
use std::time::Duration;
use tempdir;
use tempfile;
use crate::{Connection, Error, ErrorCode, Result, TransactionBehavior, NO_PARAMS};
#[test]
fn test_default_busy() {
let temp_dir = TempDir::new("test_default_busy").unwrap();
let temp_dir = tempfile::tempdir().unwrap();
let path = temp_dir.path().join("test.db3");
let mut db1 = Connection::open(&path).unwrap();
@@ -107,7 +108,7 @@ mod test {
#[test]
#[ignore] // FIXME: unstable
fn test_busy_timeout() {
let temp_dir = TempDir::new("test_busy_timeout").unwrap();
let temp_dir = tempfile::tempdir().unwrap();
let path = temp_dir.path().join("test.db3");
let db2 = Connection::open(&path).unwrap();
@@ -137,7 +138,7 @@ mod test {
#[test]
#[ignore] // FIXME: unstable
fn test_busy_handler() {
lazy_static! {
lazy_static::lazy_static! {
static ref CALLED: AtomicBool = AtomicBool::new(false);
}
fn busy_handler(_: i32) -> bool {
@@ -146,7 +147,7 @@ mod test {
true
}
let temp_dir = TempDir::new("test_busy_handler").unwrap();
let temp_dir = tempfile::tempdir().unwrap();
let path = temp_dir.path().join("test.db3");
let db2 = Connection::open(&path).unwrap();

View File

@@ -135,9 +135,12 @@ impl StatementCache {
// Return a statement to the cache.
fn cache_stmt(&self, stmt: RawStatement) {
if stmt.is_null() {
return;
}
let mut cache = self.0.borrow_mut();
stmt.clear_bindings();
let sql = String::from_utf8_lossy(stmt.sql().to_bytes())
let sql = String::from_utf8_lossy(stmt.sql().unwrap().to_bytes())
.trim()
.to_string();
cache.insert(sql, stmt);
@@ -339,4 +342,10 @@ mod test {
}
assert_eq!(1, cache.len());
}
#[test]
fn test_empty_stmt() {
let conn = Connection::open_in_memory().unwrap();
conn.prepare_cached("").unwrap();
}
}

209
src/collation.rs Normal file
View File

@@ -0,0 +1,209 @@
//! Add, remove, or modify a collation
use std::cmp::Ordering;
use std::os::raw::{c_char, c_int, c_void};
use std::panic::{catch_unwind, UnwindSafe};
use std::ptr;
use std::slice;
use crate::ffi;
use crate::{str_to_cstring, Connection, InnerConnection, Result};
// FIXME copy/paste from function.rs
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
drop(Box::from_raw(p as *mut T));
}
impl Connection {
/// Add or modify a collation.
pub fn create_collation<C>(&self, collation_name: &str, x_compare: C) -> Result<()>
where
C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
{
self.db
.borrow_mut()
.create_collation(collation_name, x_compare)
}
/// Collation needed callback
pub fn collation_needed(
&self,
x_coll_needed: fn(&Connection, &str) -> Result<()>,
) -> Result<()> {
self.db.borrow_mut().collation_needed(x_coll_needed)
}
/// Remove collation.
pub fn remove_collation(&self, collation_name: &str) -> Result<()> {
self.db.borrow_mut().remove_collation(collation_name)
}
}
impl InnerConnection {
fn create_collation<C>(&mut self, collation_name: &str, x_compare: C) -> Result<()>
where
C: Fn(&str, &str) -> Ordering + Send + UnwindSafe + 'static,
{
unsafe extern "C" fn call_boxed_closure<C>(
arg1: *mut c_void,
arg2: c_int,
arg3: *const c_void,
arg4: c_int,
arg5: *const c_void,
) -> c_int
where
C: Fn(&str, &str) -> Ordering,
{
use std::str;
let r = catch_unwind(|| {
let boxed_f: *mut C = arg1 as *mut C;
assert!(!boxed_f.is_null(), "Internal error - null function pointer");
let s1 = {
let c_slice = slice::from_raw_parts(arg3 as *const u8, arg2 as usize);
str::from_utf8_unchecked(c_slice)
};
let s2 = {
let c_slice = slice::from_raw_parts(arg5 as *const u8, arg4 as usize);
str::from_utf8_unchecked(c_slice)
};
(*boxed_f)(s1, s2)
});
let t = match r {
Err(_) => {
return -1; // FIXME How ?
}
Ok(r) => r,
};
match t {
Ordering::Less => -1,
Ordering::Equal => 0,
Ordering::Greater => 1,
}
}
let boxed_f: *mut C = Box::into_raw(Box::new(x_compare));
let c_name = str_to_cstring(collation_name)?;
let flags = ffi::SQLITE_UTF8;
let r = unsafe {
ffi::sqlite3_create_collation_v2(
self.db(),
c_name.as_ptr(),
flags,
boxed_f as *mut c_void,
Some(call_boxed_closure::<C>),
Some(free_boxed_value::<C>),
)
};
self.decode_result(r)
}
fn collation_needed(
&mut self,
x_coll_needed: fn(&Connection, &str) -> Result<()>,
) -> Result<()> {
use std::mem;
unsafe extern "C" fn collation_needed_callback(
arg1: *mut c_void,
arg2: *mut ffi::sqlite3,
e_text_rep: c_int,
arg3: *const c_char,
) {
use std::ffi::CStr;
use std::str;
if e_text_rep != ffi::SQLITE_UTF8 {
// TODO: validate
return;
}
let callback: fn(&Connection, &str) -> Result<()> = mem::transmute(arg1);
if catch_unwind(|| {
let conn = Connection::from_handle(arg2).unwrap();
let collation_name = {
let c_slice = CStr::from_ptr(arg3).to_bytes();
str::from_utf8_unchecked(c_slice)
};
callback(&conn, collation_name)
})
.is_err()
{
return; // FIXME How ?
}
}
let r = unsafe {
ffi::sqlite3_collation_needed(
self.db(),
mem::transmute(x_coll_needed),
Some(collation_needed_callback),
)
};
self.decode_result(r)
}
fn remove_collation(&mut self, collation_name: &str) -> Result<()> {
let c_name = str_to_cstring(collation_name)?;
let r = unsafe {
ffi::sqlite3_create_collation_v2(
self.db(),
c_name.as_ptr(),
ffi::SQLITE_UTF8,
ptr::null_mut(),
None,
None,
)
};
self.decode_result(r)
}
}
#[cfg(test)]
mod test {
use crate::{Connection, Result, NO_PARAMS};
use fallible_streaming_iterator::FallibleStreamingIterator;
use std::cmp::Ordering;
use unicase::UniCase;
fn unicase_compare(s1: &str, s2: &str) -> Ordering {
UniCase::new(s1).cmp(&UniCase::new(s2))
}
#[test]
fn test_unicase() {
let db = Connection::open_in_memory().unwrap();
db.create_collation("unicase", unicase_compare).unwrap();
collate(db);
}
fn collate(db: Connection) {
db.execute_batch(
"CREATE TABLE foo (bar);
INSERT INTO foo (bar) VALUES ('Maße');
INSERT INTO foo (bar) VALUES ('MASSE');",
)
.unwrap();
let mut stmt = db
.prepare("SELECT DISTINCT bar COLLATE unicase FROM foo ORDER BY 1")
.unwrap();
let rows = stmt.query(NO_PARAMS).unwrap();
assert_eq!(rows.count().unwrap(), 1);
}
fn collation_needed(db: &Connection, collation_name: &str) -> Result<()> {
if "unicase" == collation_name {
db.create_collation(collation_name, unicase_compare)
} else {
Ok(())
}
}
#[test]
fn test_collation_needed() {
let db = Connection::open_in_memory().unwrap();
db.collation_needed(collation_needed).unwrap();
collate(db);
}
}

View File

@@ -27,8 +27,7 @@ impl Statement<'_> {
let n = self.column_count();
let mut cols = Vec::with_capacity(n as usize);
for i in 0..n {
let slice = self.stmt.column_name(i);
let s = str::from_utf8(slice.to_bytes()).unwrap();
let s = self.column_name_unwrap(i);
cols.push(s);
}
cols
@@ -40,6 +39,30 @@ impl Statement<'_> {
self.stmt.column_count()
}
pub(crate) fn column_name_unwrap(&self, col: usize) -> &str {
// Just panic if the bounds are wrong for now, we never call this
// without checking first.
self.column_name(col).expect("Column out of bounds")
}
/// Returns the name assigned to a particular column in the result set
/// returned by the prepared statement.
///
/// ## Failure
///
/// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid
/// column range for this row.
///
/// Panics when column name is not valid UTF-8.
pub fn column_name(&self, col: usize) -> Result<&str> {
self.stmt
.column_name(col)
.ok_or(Error::InvalidColumnIndex(col))
.map(|slice| {
str::from_utf8(slice.to_bytes()).expect("Invalid UTF-8 sequence in column name")
})
}
/// Returns the column index in the result set for a given column name.
///
/// If there is no AS clause then the name of the column is unspecified and
@@ -53,7 +76,9 @@ impl Statement<'_> {
let bytes = name.as_bytes();
let n = self.column_count();
for i in 0..n {
if bytes.eq_ignore_ascii_case(self.stmt.column_name(i).to_bytes()) {
// Note: `column_name` is only fallible if `i` is out of bounds,
// which we've already checked.
if bytes.eq_ignore_ascii_case(self.stmt.column_name(i).unwrap().to_bytes()) {
return Ok(i);
}
}
@@ -61,14 +86,15 @@ impl Statement<'_> {
}
/// Returns a slice describing the columns of the result of the query.
pub fn columns<'stmt>(&'stmt self) -> Vec<Column<'stmt>> {
pub fn columns(&self) -> Vec<Column> {
let n = self.column_count();
let mut cols = Vec::with_capacity(n as usize);
for i in 0..n {
let slice = self.stmt.column_name(i);
let name = str::from_utf8(slice.to_bytes()).unwrap();
let name = self.column_name_unwrap(i);
let slice = self.stmt.column_decltype(i);
let decl_type = slice.map(|s| str::from_utf8(s.to_bytes()).unwrap());
let decl_type = slice.map(|s| {
str::from_utf8(s.to_bytes()).expect("Invalid UTF-8 sequence in column declaration")
});
cols.push(Column { name, decl_type });
}
cols
@@ -86,20 +112,45 @@ impl<'stmt> Rows<'stmt> {
self.stmt.map(Statement::column_count)
}
/// Return the name of the column.
pub fn column_name(&self, col: usize) -> Option<Result<&str>> {
self.stmt.map(|stmt| stmt.column_name(col))
}
/// Return the index of the column.
pub fn column_index(&self, name: &str) -> Option<Result<usize>> {
self.stmt.map(|stmt| stmt.column_index(name))
}
/// Returns a slice describing the columns of the Rows.
pub fn columns(&self) -> Option<Vec<Column<'stmt>>> {
pub fn columns(&self) -> Option<Vec<Column>> {
self.stmt.map(Statement::columns)
}
}
impl<'stmt> Row<'stmt> {
/// Get all the column names of the Row.
pub fn column_names(&self) -> Vec<&str> {
self.stmt.column_names()
}
/// Return the number of columns in the current row.
pub fn column_count(&self) -> usize {
self.stmt.column_count()
}
/// Return the name of the column.
pub fn column_name(&self, col: usize) -> Result<&str> {
self.stmt.column_name(col)
}
/// Return the index of the column.
pub fn column_index(&self, name: &str) -> Result<usize> {
self.stmt.column_index(name)
}
/// Returns a slice describing the columns of the Row.
pub fn columns(&self) -> Vec<Column<'stmt>> {
pub fn columns(&self) -> Vec<Column> {
self.stmt.columns()
}
}
@@ -125,4 +176,40 @@ mod test {
&[Some("text"), Some("text"), Some("text"),]
);
}
#[test]
fn test_column_name_in_error() {
use crate::{types::Type, Error};
let db = Connection::open_in_memory().unwrap();
db.execute_batch(
"BEGIN;
CREATE TABLE foo(x INTEGER, y TEXT);
INSERT INTO foo VALUES(4, NULL);
END;",
)
.unwrap();
let mut stmt = db.prepare("SELECT x as renamed, y FROM foo").unwrap();
let mut rows = stmt.query(crate::NO_PARAMS).unwrap();
let row = rows.next().unwrap().unwrap();
match row.get::<_, String>(0).unwrap_err() {
Error::InvalidColumnType(idx, name, ty) => {
assert_eq!(idx, 0);
assert_eq!(name, "renamed");
assert_eq!(ty, Type::Integer);
}
e => {
panic!("Unexpected error type: {:?}", e);
}
}
match row.get::<_, String>("y").unwrap_err() {
Error::InvalidColumnType(idx, name, ty) => {
assert_eq!(idx, 1);
assert_eq!(name, "y");
assert_eq!(ty, Type::Null);
}
e => {
panic!("Unexpected error type: {:?}", e);
}
}
}
}

View File

@@ -15,11 +15,19 @@ pub enum DbConfig {
SQLITE_DBCONFIG_ENABLE_TRIGGER = 1003,
SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER = 1004, // 3.12.0
//SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION = 1005,
SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE = 1006,
SQLITE_DBCONFIG_ENABLE_QPSG = 1007, // 3.20.0
SQLITE_DBCONFIG_TRIGGER_EQP = 1008,
SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE = 1006, // 3.16.2
SQLITE_DBCONFIG_ENABLE_QPSG = 1007, // 3.20.0
SQLITE_DBCONFIG_TRIGGER_EQP = 1008, // 3.22.0
//SQLITE_DBCONFIG_RESET_DATABASE = 1009,
SQLITE_DBCONFIG_DEFENSIVE = 1010,
SQLITE_DBCONFIG_DEFENSIVE = 1010, // 3.26.0
#[cfg(feature = "modern_sqlite")]
SQLITE_DBCONFIG_WRITABLE_SCHEMA = 1011, // 3.28.0
#[cfg(feature = "modern_sqlite")]
SQLITE_DBCONFIG_LEGACY_ALTER_TABLE = 1012, // 3.29
#[cfg(feature = "modern_sqlite")]
SQLITE_DBCONFIG_DQS_DML = 1013, // 3.29.0
#[cfg(feature = "modern_sqlite")]
SQLITE_DBCONFIG_DQS_DDL = 1014, // 3.29.0
}
impl Connection {

View File

@@ -1,3 +1,4 @@
use crate::types::FromSqlError;
use crate::types::Type;
use crate::{errmsg_to_string, ffi};
use std::error;
@@ -59,7 +60,7 @@ pub enum Error {
/// Error when the value of a particular column is requested, but the type
/// of the result in that column cannot be converted to the requested
/// Rust type.
InvalidColumnType(usize, Type),
InvalidColumnType(usize, String, Type),
/// Error when a query that was expected to insert one row did not insert
/// any or insert many.
@@ -99,6 +100,9 @@ pub enum Error {
/// of a different type than what had been stored using `Context::set_aux`.
#[cfg(feature = "functions")]
GetAuxWrongType,
/// Error when the SQL contains multiple statements.
MultipleStatement,
}
impl PartialEq for Error {
@@ -117,8 +121,8 @@ impl PartialEq for Error {
(Error::QueryReturnedNoRows, Error::QueryReturnedNoRows) => true,
(Error::InvalidColumnIndex(i1), Error::InvalidColumnIndex(i2)) => i1 == i2,
(Error::InvalidColumnName(n1), Error::InvalidColumnName(n2)) => n1 == n2,
(Error::InvalidColumnType(i1, t1), Error::InvalidColumnType(i2, t2)) => {
i1 == i2 && t1 == t2
(Error::InvalidColumnType(i1, n1, t1), Error::InvalidColumnType(i2, n2, t2)) => {
i1 == i2 && t1 == t2 && n1 == n2
}
(Error::StatementChangedRows(n1), Error::StatementChangedRows(n2)) => n1 == n2,
#[cfg(feature = "functions")]
@@ -138,7 +142,7 @@ impl PartialEq for Error {
(Error::UnwindingPanic, Error::UnwindingPanic) => true,
#[cfg(feature = "functions")]
(Error::GetAuxWrongType, Error::GetAuxWrongType) => true,
(_, _) => false,
(..) => false,
}
}
}
@@ -155,6 +159,32 @@ impl From<::std::ffi::NulError> for Error {
}
}
const UNKNOWN_COLUMN: usize = std::usize::MAX;
/// The conversion isn't precise, but it's convenient to have it
/// to allow use of `get_raw(…).as_…()?` in callbacks that take `Error`.
impl From<FromSqlError> for Error {
fn from(err: FromSqlError) -> Error {
// The error type requires index and type fields, but they aren't known in this
// context.
match err {
FromSqlError::OutOfRange(val) => Error::IntegralValueOutOfRange(UNKNOWN_COLUMN, val),
#[cfg(feature = "i128_blob")]
FromSqlError::InvalidI128Size(_) => {
Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Blob, Box::new(err))
}
#[cfg(feature = "uuid")]
FromSqlError::InvalidUuidSize(_) => {
Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Blob, Box::new(err))
}
FromSqlError::Other(source) => {
Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Null, source)
}
_ => Error::FromSqlConversionFailure(UNKNOWN_COLUMN, Type::Null, Box::new(err)),
}
}
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
@@ -164,13 +194,23 @@ impl fmt::Display for Error {
f,
"SQLite was compiled or configured for single-threaded use only"
),
Error::FromSqlConversionFailure(i, ref t, ref err) => write!(
f,
"Conversion error from type {} at index: {}, {}",
t, i, err
),
Error::FromSqlConversionFailure(i, ref t, ref err) => {
if i != UNKNOWN_COLUMN {
write!(
f,
"Conversion error from type {} at index: {}, {}",
t, i, err
)
} else {
err.fmt(f)
}
}
Error::IntegralValueOutOfRange(col, val) => {
write!(f, "Integer {} out of range at index {}", val, col)
if col != UNKNOWN_COLUMN {
write!(f, "Integer {} out of range at index {}", val, col)
} else {
write!(f, "Integer {} out of range", val)
}
}
Error::Utf8Error(ref err) => err.fmt(f),
Error::NulError(ref err) => err.fmt(f),
@@ -182,9 +222,11 @@ impl fmt::Display for Error {
Error::QueryReturnedNoRows => write!(f, "Query returned no rows"),
Error::InvalidColumnIndex(i) => write!(f, "Invalid column index: {}", i),
Error::InvalidColumnName(ref name) => write!(f, "Invalid column name: {}", name),
Error::InvalidColumnType(i, ref t) => {
write!(f, "Invalid column type {} at index: {}", t, i)
}
Error::InvalidColumnType(i, ref name, ref t) => write!(
f,
"Invalid column type {} at index: {}, name: {}",
t, i, name
),
Error::StatementChangedRows(i) => write!(f, "Query changed {} rows", i),
#[cfg(feature = "functions")]
@@ -205,11 +247,13 @@ impl fmt::Display for Error {
Error::UnwindingPanic => write!(f, "unwinding panic"),
#[cfg(feature = "functions")]
Error::GetAuxWrongType => write!(f, "get_aux called with wrong type"),
Error::MultipleStatement => write!(f, "Multiple statements provided"),
}
}
}
impl error::Error for Error {
#[allow(deprecated)]
fn description(&self) -> &str {
match *self {
Error::SqliteFailure(ref err, None) => err.description(),
@@ -218,7 +262,7 @@ impl error::Error for Error {
"SQLite was compiled or configured for single-threaded use only"
}
Error::FromSqlConversionFailure(_, _, ref err) => err.description(),
Error::IntegralValueOutOfRange(_, _) => "integral value out of range of requested type",
Error::IntegralValueOutOfRange(..) => "integral value out of range of requested type",
Error::Utf8Error(ref err) => err.description(),
Error::InvalidParameterName(_) => "invalid parameter name",
Error::NulError(ref err) => err.description(),
@@ -229,13 +273,13 @@ impl error::Error for Error {
Error::QueryReturnedNoRows => "query returned no rows",
Error::InvalidColumnIndex(_) => "invalid column index",
Error::InvalidColumnName(_) => "invalid column name",
Error::InvalidColumnType(_, _) => "invalid column type",
Error::InvalidColumnType(..) => "invalid column type",
Error::StatementChangedRows(_) => "query inserted zero or more than one row",
#[cfg(feature = "functions")]
Error::InvalidFunctionParameterType(_, _) => "invalid function parameter type",
Error::InvalidFunctionParameterType(..) => "invalid function parameter type",
#[cfg(feature = "vtab")]
Error::InvalidFilterParameterType(_, _) => "invalid filter parameter type",
Error::InvalidFilterParameterType(..) => "invalid filter parameter type",
#[cfg(feature = "functions")]
Error::UserFunctionError(ref err) => err.description(),
Error::ToSqlConversionFailure(ref err) => err.description(),
@@ -246,6 +290,7 @@ impl error::Error for Error {
Error::UnwindingPanic => "unwinding panic",
#[cfg(feature = "functions")]
Error::GetAuxWrongType => "get_aux called with wrong type",
Error::MultipleStatement => "multiple statements provided",
}
}
@@ -255,22 +300,23 @@ impl error::Error for Error {
Error::Utf8Error(ref err) => Some(err),
Error::NulError(ref err) => Some(err),
Error::IntegralValueOutOfRange(_, _)
Error::IntegralValueOutOfRange(..)
| Error::SqliteSingleThreadedMode
| Error::InvalidParameterName(_)
| Error::ExecuteReturnedResults
| Error::QueryReturnedNoRows
| Error::InvalidColumnIndex(_)
| Error::InvalidColumnName(_)
| Error::InvalidColumnType(_, _)
| Error::InvalidColumnType(..)
| Error::InvalidPath(_)
| Error::StatementChangedRows(_)
| Error::InvalidQuery => None,
| Error::InvalidQuery
| Error::MultipleStatement => None,
#[cfg(feature = "functions")]
Error::InvalidFunctionParameterType(_, _) => None,
Error::InvalidFunctionParameterType(..) => None,
#[cfg(feature = "vtab")]
Error::InvalidFilterParameterType(_, _) => None,
Error::InvalidFilterParameterType(..) => None,
#[cfg(feature = "functions")]
Error::UserFunctionError(ref err) => Some(&**err),
@@ -309,7 +355,7 @@ macro_rules! check {
($funcall:expr) => {{
let rc = $funcall;
if rc != crate::ffi::SQLITE_OK {
Err(crate::error::error_from_sqlite_code(rc, None))?;
return Err(crate::error::error_from_sqlite_code(rc, None).into());
}
}};
}

View File

@@ -10,28 +10,47 @@
//!
//! ```rust
//! use regex::Regex;
//! use rusqlite::functions::FunctionFlags;
//! use rusqlite::{Connection, Error, Result, NO_PARAMS};
//! use std::collections::HashMap;
//!
//! fn add_regexp_function(db: &Connection) -> Result<()> {
//! let mut cached_regexes = HashMap::new();
//! db.create_scalar_function("regexp", 2, true, move |ctx| {
//! let regex_s = ctx.get::<String>(0)?;
//! let entry = cached_regexes.entry(regex_s.clone());
//! let regex = {
//! use std::collections::hash_map::Entry::{Occupied, Vacant};
//! match entry {
//! Occupied(occ) => occ.into_mut(),
//! Vacant(vac) => match Regex::new(&regex_s) {
//! Ok(r) => vac.insert(r),
//! Err(err) => return Err(Error::UserFunctionError(Box::new(err))),
//! },
//! }
//! };
//! db.create_scalar_function(
//! "regexp",
//! 2,
//! FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
//! move |ctx| {
//! assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
//!
//! let text = ctx.get::<String>(1)?;
//! Ok(regex.is_match(&text))
//! })
//! let saved_re: Option<&Regex> = ctx.get_aux(0)?;
//! let new_re = match saved_re {
//! None => {
//! let s = ctx.get::<String>(0)?;
//! match Regex::new(&s) {
//! Ok(r) => Some(r),
//! Err(err) => return Err(Error::UserFunctionError(Box::new(err))),
//! }
//! }
//! Some(_) => None,
//! };
//!
//! let is_match = {
//! let re = saved_re.unwrap_or_else(|| new_re.as_ref().unwrap());
//!
//! let text = ctx
//! .get_raw(1)
//! .as_str()
//! .map_err(|e| Error::UserFunctionError(e.into()))?;
//!
//! re.is_match(text)
//! };
//!
//! if let Some(re) = new_re {
//! ctx.set_aux(0, re);
//! }
//!
//! Ok(is_match)
//! },
//! )
//! }
//!
//! fn main() -> Result<()> {
@@ -48,7 +67,6 @@
//! Ok(())
//! }
//! ```
use std::error::Error as StdError;
use std::os::raw::{c_int, c_void};
use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
use std::ptr;
@@ -68,11 +86,11 @@ unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
// an explicit feature check for that, and this doesn't really warrant one.
// We'll use the extended code if we're on the bundled version (since it's
// at least 3.17.0) and the normal constraint error code if not.
#[cfg(feature = "bundled")]
#[cfg(feature = "modern_sqlite")]
fn constraint_error_code() -> i32 {
ffi::SQLITE_CONSTRAINT_FUNCTION
}
#[cfg(not(feature = "bundled"))]
#[cfg(not(feature = "modern_sqlite"))]
fn constraint_error_code() -> i32 {
ffi::SQLITE_CONSTRAINT
}
@@ -86,7 +104,7 @@ unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
}
_ => {
ffi::sqlite3_result_error_code(ctx, constraint_error_code());
if let Ok(cstr) = str_to_cstring(err.description()) {
if let Ok(cstr) = str_to_cstring(&err.to_string()) {
ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
}
}
@@ -213,6 +231,44 @@ where
fn finalize(&self, _: Option<A>) -> Result<T>;
}
/// WindowAggregate is the callback interface for user-defined aggregate window
/// function.
#[cfg(feature = "window")]
pub trait WindowAggregate<A, T>: Aggregate<A, T>
where
A: RefUnwindSafe + UnwindSafe,
T: ToSql,
{
/// Returns the current value of the aggregate. Unlike xFinal, the
/// implementation should not delete any context.
fn value(&self, _: Option<&A>) -> Result<T>;
/// Removes a row from the current window.
fn inverse(&self, _: &mut Context<'_>, _: &mut A) -> Result<()>;
}
bitflags::bitflags! {
#[doc = "Function Flags."]
#[doc = "See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html) for details."]
#[repr(C)]
pub struct FunctionFlags: ::std::os::raw::c_int {
const SQLITE_UTF8 = ffi::SQLITE_UTF8;
const SQLITE_UTF16LE = ffi::SQLITE_UTF16LE;
const SQLITE_UTF16BE = ffi::SQLITE_UTF16BE;
const SQLITE_UTF16 = ffi::SQLITE_UTF16;
const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC;
const SQLITE_DIRECTONLY = 0x0000_0008_0000; // 3.30.0
const SQLITE_SUBTYPE = 0x0000_0010_0000; // 3.30.0
const SQLITE_INNOCUOUS = 0x0000_0020_0000; // 3.31.0
}
}
impl Default for FunctionFlags {
fn default() -> FunctionFlags {
FunctionFlags::SQLITE_UTF8
}
}
impl Connection {
/// Attach a user-defined scalar function to this database connection.
///
@@ -228,11 +284,17 @@ impl Connection {
///
/// ```rust
/// # use rusqlite::{Connection, Result, NO_PARAMS};
/// # use rusqlite::functions::FunctionFlags;
/// fn scalar_function_example(db: Connection) -> Result<()> {
/// db.create_scalar_function("halve", 1, true, |ctx| {
/// let value = ctx.get::<f64>(0)?;
/// Ok(value / 2f64)
/// })?;
/// db.create_scalar_function(
/// "halve",
/// 1,
/// FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
/// |ctx| {
/// let value = ctx.get::<f64>(0)?;
/// Ok(value / 2f64)
/// },
/// )?;
///
/// let six_halved: f64 = db.query_row("SELECT halve(6)", NO_PARAMS, |r| r.get(0))?;
/// assert_eq!(six_halved, 3f64);
@@ -247,7 +309,7 @@ impl Connection {
&self,
fn_name: &str,
n_arg: c_int,
deterministic: bool,
flags: FunctionFlags,
x_func: F,
) -> Result<()>
where
@@ -256,7 +318,7 @@ impl Connection {
{
self.db
.borrow_mut()
.create_scalar_function(fn_name, n_arg, deterministic, x_func)
.create_scalar_function(fn_name, n_arg, flags, x_func)
}
/// Attach a user-defined aggregate function to this database connection.
@@ -268,7 +330,7 @@ impl Connection {
&self,
fn_name: &str,
n_arg: c_int,
deterministic: bool,
flags: FunctionFlags,
aggr: D,
) -> Result<()>
where
@@ -278,7 +340,25 @@ impl Connection {
{
self.db
.borrow_mut()
.create_aggregate_function(fn_name, n_arg, deterministic, aggr)
.create_aggregate_function(fn_name, n_arg, flags, aggr)
}
#[cfg(feature = "window")]
pub fn create_window_function<A, W, T>(
&self,
fn_name: &str,
n_arg: c_int,
flags: FunctionFlags,
aggr: W,
) -> Result<()>
where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T>,
T: ToSql,
{
self.db
.borrow_mut()
.create_window_function(fn_name, n_arg, flags, aggr)
}
/// Removes a user-defined function from this database connection.
@@ -299,7 +379,7 @@ impl InnerConnection {
&mut self,
fn_name: &str,
n_arg: c_int,
deterministic: bool,
flags: FunctionFlags,
x_func: F,
) -> Result<()>
where
@@ -341,16 +421,12 @@ impl InnerConnection {
let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
let c_name = str_to_cstring(fn_name)?;
let mut flags = ffi::SQLITE_UTF8;
if deterministic {
flags |= ffi::SQLITE_DETERMINISTIC;
}
let r = unsafe {
ffi::sqlite3_create_function_v2(
self.db(),
c_name.as_ptr(),
n_arg,
flags,
flags.bits(),
boxed_f as *mut c_void,
Some(call_boxed_closure::<F, T>),
None,
@@ -365,7 +441,7 @@ impl InnerConnection {
&mut self,
fn_name: &str,
n_arg: c_int,
deterministic: bool,
flags: FunctionFlags,
aggr: D,
) -> Result<()>
where
@@ -373,117 +449,14 @@ impl InnerConnection {
D: Aggregate<A, T>,
T: ToSql,
{
unsafe fn aggregate_context<A>(
ctx: *mut sqlite3_context,
bytes: usize,
) -> Option<*mut *mut A> {
let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
if pac.is_null() {
return None;
}
Some(pac)
}
unsafe extern "C" fn call_boxed_step<A, D, T>(
ctx: *mut sqlite3_context,
argc: c_int,
argv: *mut *mut sqlite3_value,
) where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>,
T: ToSql,
{
let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
Some(pac) => pac,
None => {
ffi::sqlite3_result_error_nomem(ctx);
return;
}
};
let r = catch_unwind(|| {
let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
);
if (*pac as *mut A).is_null() {
*pac = Box::into_raw(Box::new((*boxed_aggr).init()));
}
let mut ctx = Context {
ctx,
args: slice::from_raw_parts(argv, argc as usize),
};
(*boxed_aggr).step(&mut ctx, &mut **pac)
});
let r = match r {
Err(_) => {
report_error(ctx, &Error::UnwindingPanic);
return;
}
Ok(r) => r,
};
match r {
Ok(_) => {}
Err(err) => report_error(ctx, &err),
};
}
unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>,
T: ToSql,
{
// Within the xFinal callback, it is customary to set N=0 in calls to
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
let a: Option<A> = match aggregate_context(ctx, 0) {
Some(pac) => {
if (*pac as *mut A).is_null() {
None
} else {
let a = Box::from_raw(*pac);
Some(*a)
}
}
None => None,
};
let r = catch_unwind(|| {
let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
);
(*boxed_aggr).finalize(a)
});
let t = match r {
Err(_) => {
report_error(ctx, &Error::UnwindingPanic);
return;
}
Ok(r) => r,
};
let t = t.as_ref().map(|t| ToSql::to_sql(t));
match t {
Ok(Ok(ref value)) => set_result(ctx, value),
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
}
}
let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
let c_name = str_to_cstring(fn_name)?;
let mut flags = ffi::SQLITE_UTF8;
if deterministic {
flags |= ffi::SQLITE_DETERMINISTIC;
}
let r = unsafe {
ffi::sqlite3_create_function_v2(
self.db(),
c_name.as_ptr(),
n_arg,
flags,
flags.bits(),
boxed_aggr as *mut c_void,
None,
Some(call_boxed_step::<A, D, T>),
@@ -494,6 +467,38 @@ impl InnerConnection {
self.decode_result(r)
}
#[cfg(feature = "window")]
fn create_window_function<A, W, T>(
&mut self,
fn_name: &str,
n_arg: c_int,
flags: FunctionFlags,
aggr: W,
) -> Result<()>
where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T>,
T: ToSql,
{
let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
let c_name = str_to_cstring(fn_name)?;
let r = unsafe {
ffi::sqlite3_create_window_function(
self.db(),
c_name.as_ptr(),
n_arg,
flags.bits(),
boxed_aggr as *mut c_void,
Some(call_boxed_step::<A, W, T>),
Some(call_boxed_final::<A, W, T>),
Some(call_boxed_value::<A, W, T>),
Some(call_boxed_inverse::<A, W, T>),
Some(free_boxed_value::<W>),
)
};
self.decode_result(r)
}
fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
let c_name = str_to_cstring(fn_name)?;
let r = unsafe {
@@ -513,16 +518,198 @@ impl InnerConnection {
}
}
unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
if pac.is_null() {
return None;
}
Some(pac)
}
unsafe extern "C" fn call_boxed_step<A, D, T>(
ctx: *mut sqlite3_context,
argc: c_int,
argv: *mut *mut sqlite3_value,
) where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>,
T: ToSql,
{
let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
Some(pac) => pac,
None => {
ffi::sqlite3_result_error_nomem(ctx);
return;
}
};
let r = catch_unwind(|| {
let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
);
if (*pac as *mut A).is_null() {
*pac = Box::into_raw(Box::new((*boxed_aggr).init()));
}
let mut ctx = Context {
ctx,
args: slice::from_raw_parts(argv, argc as usize),
};
(*boxed_aggr).step(&mut ctx, &mut **pac)
});
let r = match r {
Err(_) => {
report_error(ctx, &Error::UnwindingPanic);
return;
}
Ok(r) => r,
};
match r {
Ok(_) => {}
Err(err) => report_error(ctx, &err),
};
}
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
ctx: *mut sqlite3_context,
argc: c_int,
argv: *mut *mut sqlite3_value,
) where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T>,
T: ToSql,
{
let pac = match aggregate_context(ctx, ::std::mem::size_of::<*mut A>()) {
Some(pac) => pac,
None => {
ffi::sqlite3_result_error_nomem(ctx);
return;
}
};
let r = catch_unwind(|| {
let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
);
let mut ctx = Context {
ctx,
args: slice::from_raw_parts(argv, argc as usize),
};
(*boxed_aggr).inverse(&mut ctx, &mut **pac)
});
let r = match r {
Err(_) => {
report_error(ctx, &Error::UnwindingPanic);
return;
}
Ok(r) => r,
};
match r {
Ok(_) => {}
Err(err) => report_error(ctx, &err),
};
}
unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
where
A: RefUnwindSafe + UnwindSafe,
D: Aggregate<A, T>,
T: ToSql,
{
// Within the xFinal callback, it is customary to set N=0 in calls to
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
let a: Option<A> = match aggregate_context(ctx, 0) {
Some(pac) => {
if (*pac as *mut A).is_null() {
None
} else {
let a = Box::from_raw(*pac);
Some(*a)
}
}
None => None,
};
let r = catch_unwind(|| {
let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx) as *mut D;
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
);
(*boxed_aggr).finalize(a)
});
let t = match r {
Err(_) => {
report_error(ctx, &Error::UnwindingPanic);
return;
}
Ok(r) => r,
};
let t = t.as_ref().map(|t| ToSql::to_sql(t));
match t {
Ok(Ok(ref value)) => set_result(ctx, value),
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
}
}
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
A: RefUnwindSafe + UnwindSafe,
W: WindowAggregate<A, T>,
T: ToSql,
{
// Within the xValue callback, it is customary to set N=0 in calls to
// sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
let a: Option<&A> = match aggregate_context(ctx, 0) {
Some(pac) => {
if (*pac as *mut A).is_null() {
None
} else {
let a = &**pac;
Some(a)
}
}
None => None,
};
let r = catch_unwind(|| {
let boxed_aggr: *mut W = ffi::sqlite3_user_data(ctx) as *mut W;
assert!(
!boxed_aggr.is_null(),
"Internal error - null aggregate pointer"
);
(*boxed_aggr).value(a)
});
let t = match r {
Err(_) => {
report_error(ctx, &Error::UnwindingPanic);
return;
}
Ok(r) => r,
};
let t = t.as_ref().map(|t| ToSql::to_sql(t));
match t {
Ok(Ok(ref value)) => set_result(ctx, value),
Ok(Err(err)) => report_error(ctx, &err),
Err(err) => report_error(ctx, err),
}
}
#[cfg(test)]
mod test {
use regex;
use self::regex::Regex;
use std::collections::HashMap;
use regex::Regex;
use std::f64::EPSILON;
use std::os::raw::c_double;
use crate::functions::{Aggregate, Context};
#[cfg(feature = "window")]
use crate::functions::WindowAggregate;
use crate::functions::{Aggregate, Context, FunctionFlags};
use crate::{Connection, Error, Result, NO_PARAMS};
fn half(ctx: &Context<'_>) -> Result<c_double> {
@@ -534,7 +721,13 @@ mod test {
#[test]
fn test_function_half() {
let db = Connection::open_in_memory().unwrap();
db.create_scalar_function("half", 1, true, half).unwrap();
db.create_scalar_function(
"half",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
half,
)
.unwrap();
let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
assert!((3f64 - result.unwrap()).abs() < EPSILON);
@@ -543,7 +736,13 @@ mod test {
#[test]
fn test_remove_function() {
let db = Connection::open_in_memory().unwrap();
db.create_scalar_function("half", 1, true, half).unwrap();
db.create_scalar_function(
"half",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
half,
)
.unwrap();
let result: Result<f64> = db.query_row("SELECT half(6)", NO_PARAMS, |r| r.get(0));
assert!((3f64 - result.unwrap()).abs() < EPSILON);
@@ -600,63 +799,14 @@ mod test {
END;",
)
.unwrap();
db.create_scalar_function("regexp", 2, true, regexp_with_auxilliary)
.unwrap();
let result: Result<bool> =
db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", NO_PARAMS, |r| {
r.get(0)
});
assert_eq!(true, result.unwrap());
let result: Result<i64> = db.query_row(
"SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1",
NO_PARAMS,
|r| r.get(0),
);
assert_eq!(2, result.unwrap());
}
#[test]
fn test_function_regexp_with_hashmap_cache() {
let db = Connection::open_in_memory().unwrap();
db.execute_batch(
"BEGIN;
CREATE TABLE foo (x string);
INSERT INTO foo VALUES ('lisa');
INSERT INTO foo VALUES ('lXsi');
INSERT INTO foo VALUES ('lisX');
END;",
db.create_scalar_function(
"regexp",
2,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
regexp_with_auxilliary,
)
.unwrap();
// This implementation of a regexp scalar function uses a captured HashMap
// to keep cached regular expressions around (even across multiple queries)
// until the function is removed.
let mut cached_regexes = HashMap::new();
db.create_scalar_function("regexp", 2, true, move |ctx| {
assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
let regex_s = ctx.get::<String>(0)?;
let entry = cached_regexes.entry(regex_s.clone());
let regex = {
use std::collections::hash_map::Entry::{Occupied, Vacant};
match entry {
Occupied(occ) => occ.into_mut(),
Vacant(vac) => match Regex::new(&regex_s) {
Ok(r) => vac.insert(r),
Err(err) => return Err(Error::UserFunctionError(Box::new(err))),
},
}
};
let text = ctx.get::<String>(1)?;
Ok(regex.is_match(&text))
})
.unwrap();
let result: Result<bool> =
db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", NO_PARAMS, |r| {
r.get(0)
@@ -676,16 +826,21 @@ mod test {
#[test]
fn test_varargs_function() {
let db = Connection::open_in_memory().unwrap();
db.create_scalar_function("my_concat", -1, true, |ctx| {
let mut ret = String::new();
db.create_scalar_function(
"my_concat",
-1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
|ctx| {
let mut ret = String::new();
for idx in 0..ctx.len() {
let s = ctx.get::<String>(idx)?;
ret.push_str(&s);
}
for idx in 0..ctx.len() {
let s = ctx.get::<String>(idx)?;
ret.push_str(&s);
}
Ok(ret)
})
Ok(ret)
},
)
.unwrap();
for &(expected, query) in &[
@@ -701,7 +856,7 @@ mod test {
#[test]
fn test_get_aux_type_checking() {
let db = Connection::open_in_memory().unwrap();
db.create_scalar_function("example", 2, false, |ctx| {
db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
if !ctx.get::<bool>(1)? {
ctx.set_aux::<i64>(0, 100);
} else {
@@ -759,8 +914,13 @@ mod test {
#[test]
fn test_sum() {
let db = Connection::open_in_memory().unwrap();
db.create_aggregate_function("my_sum", 1, true, Sum)
.unwrap();
db.create_aggregate_function(
"my_sum",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
Sum,
)
.unwrap();
// sum should return NULL when given no columns (contrast with count below)
let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
@@ -782,8 +942,13 @@ mod test {
#[test]
fn test_count() {
let db = Connection::open_in_memory().unwrap();
db.create_aggregate_function("my_count", -1, true, Count)
.unwrap();
db.create_aggregate_function(
"my_count",
-1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
Count,
)
.unwrap();
// count should return 0 when given no columns (contrast with sum above)
let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
@@ -794,4 +959,64 @@ mod test {
let result: i64 = db.query_row(single_sum, NO_PARAMS, |r| r.get(0)).unwrap();
assert_eq!(2, result);
}
#[cfg(feature = "window")]
impl WindowAggregate<i64, Option<i64>> for Sum {
fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
*sum -= ctx.get::<i64>(0)?;
Ok(())
}
fn value(&self, sum: Option<&i64>) -> Result<Option<i64>> {
Ok(sum.copied())
}
}
#[test]
#[cfg(feature = "window")]
fn test_window() {
use fallible_iterator::FallibleIterator;
let db = Connection::open_in_memory().unwrap();
db.create_window_function(
"sumint",
1,
FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
Sum,
)
.unwrap();
db.execute_batch(
"CREATE TABLE t3(x, y);
INSERT INTO t3 VALUES('a', 4),
('b', 5),
('c', 3),
('d', 8),
('e', 1);",
)
.unwrap();
let mut stmt = db
.prepare(
"SELECT x, sumint(y) OVER (
ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
) AS sum_y
FROM t3 ORDER BY x;",
)
.unwrap();
let results: Vec<(String, i64)> = stmt
.query(NO_PARAMS)
.unwrap()
.map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
.collect()
.unwrap();
let expected = vec![
("a".to_owned(), 9),
("b".to_owned(), 12),
("c".to_owned(), 16),
("d".to_owned(), 12),
("e".to_owned(), 9),
];
assert_eq!(expected, results);
}
}

View File

@@ -236,6 +236,7 @@ fn free_boxed_hook<F>(p: *mut c_void) {
mod test {
use super::Action;
use crate::Connection;
use lazy_static::lazy_static;
use std::sync::atomic::{AtomicBool, Ordering};
#[test]

View File

@@ -1,12 +1,12 @@
use std::ffi::CString;
use std::mem;
use std::os::raw::c_int;
use std::mem::MaybeUninit;
use std::os::raw::{c_char, c_int};
#[cfg(feature = "load_extension")]
use std::path::Path;
use std::ptr;
use std::str;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, Once, ONCE_INIT};
use std::sync::{Arc, Mutex, Once};
use super::ffi;
use super::{str_for_sqlite, str_to_cstring};
@@ -37,6 +37,7 @@ pub struct InnerConnection {
impl InnerConnection {
#[cfg(not(feature = "hooks"))]
#[allow(clippy::mutex_atomic)]
pub fn new(db: *mut ffi::sqlite3, owned: bool) -> InnerConnection {
InnerConnection {
db,
@@ -46,6 +47,7 @@ impl InnerConnection {
}
#[cfg(feature = "hooks")]
#[allow(clippy::mutex_atomic)]
pub fn new(db: *mut ffi::sqlite3, owned: bool) -> InnerConnection {
InnerConnection {
db,
@@ -83,13 +85,28 @@ impl InnerConnection {
}
unsafe {
let mut db: *mut ffi::sqlite3 = mem::uninitialized();
let r = ffi::sqlite3_open_v2(c_path.as_ptr(), &mut db, flags.bits(), z_vfs);
let mut db = MaybeUninit::uninit();
let r =
ffi::sqlite3_open_v2(c_path.as_ptr(), db.as_mut_ptr(), flags.bits(), z_vfs);
let db: *mut ffi::sqlite3 = db.assume_init();
if r != ffi::SQLITE_OK {
let e = if db.is_null() {
error_from_sqlite_code(r, None)
error_from_sqlite_code(r, Some(c_path.to_string_lossy().to_string()))
} else {
let e = error_from_handle(db, r);
let mut e = error_from_handle(db, r);
if let Error::SqliteFailure(
ffi::Error {
code: ffi::ErrorCode::CannotOpen,
..
},
Some(msg),
) = e
{
e = Error::SqliteFailure(
ffi::Error::new(r),
Some(format!("{}: {}", msg, c_path.to_string_lossy())),
);
}
ffi::sqlite3_close(db);
e
};
@@ -126,6 +143,7 @@ impl InnerConnection {
}
}
#[allow(clippy::mutex_atomic)]
pub fn close(&mut self) -> Result<()> {
if self.db.is_null() {
return Ok(());
@@ -181,25 +199,29 @@ impl InnerConnection {
#[cfg(feature = "load_extension")]
pub fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> Result<()> {
use std::os::raw::c_char;
let dylib_str = super::path_to_cstring(dylib_path)?;
unsafe {
let mut errmsg: *mut c_char = mem::uninitialized();
let mut errmsg = MaybeUninit::uninit();
let r = if let Some(entry_point) = entry_point {
let c_entry = str_to_cstring(entry_point)?;
ffi::sqlite3_load_extension(
self.db,
dylib_str.as_ptr(),
c_entry.as_ptr(),
&mut errmsg,
errmsg.as_mut_ptr(),
)
} else {
ffi::sqlite3_load_extension(self.db, dylib_str.as_ptr(), ptr::null(), &mut errmsg)
ffi::sqlite3_load_extension(
self.db,
dylib_str.as_ptr(),
ptr::null(),
errmsg.as_mut_ptr(),
)
};
if r == ffi::SQLITE_OK {
Ok(())
} else {
let errmsg: *mut c_char = errmsg.assume_init();
let message = super::errmsg_to_string(&*errmsg);
ffi::sqlite3_free(errmsg as *mut ::std::os::raw::c_void);
Err(error_from_sqlite_code(r, Some(message)))
@@ -212,8 +234,9 @@ impl InnerConnection {
}
pub fn prepare<'a>(&mut self, conn: &'a Connection, sql: &str) -> Result<Statement<'a>> {
let mut c_stmt: *mut ffi::sqlite3_stmt = unsafe { mem::uninitialized() };
let (c_sql, len, _) = str_for_sqlite(sql)?;
let mut c_stmt = MaybeUninit::uninit();
let (c_sql, len, _) = str_for_sqlite(sql.as_bytes())?;
let mut c_tail = MaybeUninit::uninit();
let r = unsafe {
if cfg!(feature = "unlock_notify") {
let mut rc;
@@ -222,8 +245,8 @@ impl InnerConnection {
self.db(),
c_sql,
len,
&mut c_stmt,
ptr::null_mut(),
c_stmt.as_mut_ptr(),
c_tail.as_mut_ptr(),
);
if !unlock_notify::is_locked(self.db, rc) {
break;
@@ -235,11 +258,24 @@ impl InnerConnection {
}
rc
} else {
ffi::sqlite3_prepare_v2(self.db(), c_sql, len, &mut c_stmt, ptr::null_mut())
ffi::sqlite3_prepare_v2(
self.db(),
c_sql,
len,
c_stmt.as_mut_ptr(),
c_tail.as_mut_ptr(),
)
}
};
self.decode_result(r)
.map(|_| Statement::new(conn, RawStatement::new(c_stmt)))
// If there is an error, *ppStmt is set to NULL.
self.decode_result(r)?;
// If the input text contains no SQL (if the input is an empty string or a
// comment) then *ppStmt is set to NULL.
let c_stmt: *mut ffi::sqlite3_stmt = unsafe { c_stmt.assume_init() };
let c_tail: *const c_char = unsafe { c_tail.assume_init() };
// TODO ignore spaces, comments, ... at the end
let tail = !c_tail.is_null() && unsafe { c_tail != c_sql.offset(len as isize) };
Ok(Statement::new(conn, RawStatement::new(c_stmt, tail)))
}
pub fn changes(&mut self) -> usize {
@@ -250,7 +286,7 @@ impl InnerConnection {
unsafe { ffi::sqlite3_get_autocommit(self.db()) != 0 }
}
#[cfg(feature = "bundled")] // 3.8.6
#[cfg(feature = "modern_sqlite")] // 3.8.6
pub fn is_busy(&self) -> bool {
let db = self.db();
unsafe {
@@ -285,7 +321,7 @@ impl Drop for InnerConnection {
}
#[cfg(not(feature = "bundled"))]
static SQLITE_VERSION_CHECK: Once = ONCE_INIT;
static SQLITE_VERSION_CHECK: Once = Once::new();
#[cfg(not(feature = "bundled"))]
pub static BYPASS_VERSION_CHECK: AtomicBool = AtomicBool::new(false);
@@ -333,7 +369,7 @@ rusqlite was built against SQLite {} but the runtime SQLite version is {}. To fi
});
}
static SQLITE_INIT: Once = ONCE_INIT;
static SQLITE_INIT: Once = Once::new();
pub static BYPASS_SQLITE_INIT: AtomicBool = AtomicBool::new(false);
fn ensure_safe_sqlite_threading_mode() -> Result<()> {

View File

@@ -2,7 +2,6 @@
//! expose an interface similar to [rust-postgres](https://github.com/sfackler/rust-postgres).
//!
//! ```rust
//! use rusqlite::types::ToSql;
//! use rusqlite::{params, Connection, Result};
//! use time::Timespec;
//!
@@ -58,12 +57,6 @@
pub use libsqlite3_sys as ffi;
#[macro_use]
extern crate bitflags;
#[cfg(any(test, feature = "vtab"))]
#[macro_use]
extern crate lazy_static;
use std::cell::RefCell;
use std::convert;
use std::default::Default;
@@ -105,6 +98,8 @@ pub mod backup;
pub mod blob;
mod busy;
mod cache;
#[cfg(feature = "collation")]
mod collation;
mod column;
pub mod config;
#[cfg(any(feature = "functions", feature = "vtab"))]
@@ -165,7 +160,7 @@ macro_rules! params {
$crate::NO_PARAMS
};
($($param:expr),+ $(,)?) => {
&[$(&$param as &dyn $crate::ToSql),+]
&[$(&$param as &dyn $crate::ToSql),+] as &[&dyn $crate::ToSql]
};
}
@@ -246,9 +241,9 @@ fn str_to_cstring(s: &str) -> Result<CString> {
/// The `sqlite3_destructor_type` item is always `SQLITE_TRANSIENT` unless
/// the string was empty (in which case it's `SQLITE_STATIC`, and the ptr is
/// static).
fn str_for_sqlite(s: &str) -> Result<(*const c_char, c_int, ffi::sqlite3_destructor_type)> {
fn str_for_sqlite(s: &[u8]) -> Result<(*const c_char, c_int, ffi::sqlite3_destructor_type)> {
let len = len_as_c_int(s.len())?;
if memchr::memchr(0, s.as_bytes()).is_none() {
if memchr::memchr(0, s).is_none() {
let (ptr, dtor_info) = if len != 0 {
(s.as_ptr() as *const c_char, ffi::SQLITE_TRANSIENT())
} else {
@@ -300,7 +295,7 @@ pub enum DatabaseName<'a> {
feature = "backup",
feature = "blob",
feature = "session",
feature = "bundled"
feature = "modern_sqlite"
))]
impl DatabaseName<'_> {
fn to_cstring(&self) -> Result<CString> {
@@ -336,6 +331,16 @@ impl Connection {
/// OpenFlags::SQLITE_OPEN_READ_WRITE |
/// OpenFlags::SQLITE_OPEN_CREATE)`.
///
/// ```rust,no_run
/// # use rusqlite::{Connection, Result};
/// fn open_my_db() -> Result<()> {
/// let path = "./my_db.db3";
/// let db = Connection::open(&path)?;
/// println!("{}", db.is_autocommit());
/// Ok(())
/// }
/// ```
///
/// # Failure
///
/// Will return `Err` if `path` cannot be converted to a C-compatible
@@ -481,7 +486,8 @@ impl Connection {
P: IntoIterator,
P::Item: ToSql,
{
self.prepare(sql).and_then(|mut stmt| stmt.execute(params))
self.prepare(sql)
.and_then(|mut stmt| stmt.check_no_tail().and_then(|_| stmt.execute(params)))
}
/// Convenience method to prepare and execute a single SQL statement with
@@ -507,8 +513,10 @@ impl Connection {
/// Will return `Err` if `sql` cannot be converted to a C-compatible string
/// or if the underlying SQLite call fails.
pub fn execute_named(&self, sql: &str, params: &[(&str, &dyn ToSql)]) -> Result<usize> {
self.prepare(sql)
.and_then(|mut stmt| stmt.execute_named(params))
self.prepare(sql).and_then(|mut stmt| {
stmt.check_no_tail()
.and_then(|_| stmt.execute_named(params))
})
}
/// Get the SQLite rowid of the most recent successful INSERT.
@@ -553,6 +561,7 @@ impl Connection {
F: FnOnce(&Row<'_>) -> Result<T>,
{
let mut stmt = self.prepare(sql)?;
stmt.check_no_tail()?;
stmt.query_row(params, f)
}
@@ -575,9 +584,8 @@ impl Connection {
F: FnOnce(&Row<'_>) -> Result<T>,
{
let mut stmt = self.prepare(sql)?;
let mut rows = stmt.query_named(params)?;
rows.get_expected_row().and_then(|r| f(&r))
stmt.check_no_tail()?;
stmt.query_row_named(params, f)
}
/// Convenience method to execute a query that is expected to return a
@@ -613,6 +621,7 @@ impl Connection {
E: convert::From<Error>,
{
let mut stmt = self.prepare(sql)?;
stmt.check_no_tail()?;
let mut rows = stmt.query(params)?;
rows.get_expected_row().map_err(E::from).and_then(|r| f(&r))
@@ -730,7 +739,11 @@ impl Connection {
///
/// You should not need to use this function. If you do need to, please
/// [open an issue on the rusqlite repository](https://github.com/jgallagher/rusqlite/issues) and describe
/// your use case. This function is unsafe because it gives you raw access
/// your use case.
///
/// # Safety
///
/// This function is unsafe because it gives you raw access
/// to the SQLite connection, and what you do with it could impact the
/// safety of this `Connection`.
pub unsafe fn handle(&self) -> *mut ffi::sqlite3 {
@@ -775,7 +788,7 @@ impl Connection {
}
/// Determine if all associated prepared statements have been reset.
#[cfg(feature = "bundled")]
#[cfg(feature = "modern_sqlite")] // 3.8.6
pub fn is_busy(&self) -> bool {
self.db.borrow().is_busy()
}
@@ -789,7 +802,7 @@ impl fmt::Debug for Connection {
}
}
bitflags! {
bitflags::bitflags! {
#[doc = "Flags for opening SQLite database connections."]
#[doc = "See [sqlite3_open_v2](http://www.sqlite.org/c3ref/open.html) for details."]
#[repr(C)]
@@ -826,8 +839,11 @@ impl Default for OpenFlags {
/// If you are encountering that panic _and_ can ensure that SQLite has been
/// initialized in either multi-thread or serialized mode, call this function
/// prior to attempting to open a connection and rusqlite's initialization
/// process will by skipped. This
/// function is unsafe because if you call it and SQLite has actually been
/// process will by skipped.
///
/// # Safety
///
/// This function is unsafe because if you call it and SQLite has actually been
/// configured to run in single-thread mode,
/// you may enounter memory errors or data corruption or any number of terrible
/// things that should not be possible when you're using Rust.
@@ -838,11 +854,13 @@ pub unsafe fn bypass_sqlite_initialization() {
/// rusqlite performs a one-time check that the runtime SQLite version is at
/// least as new as the version of SQLite found when rusqlite was built.
/// Bypassing this check may be dangerous; e.g., if you use features of SQLite
/// that are not present in the runtime
/// version. If you are sure the runtime version is compatible with the
/// that are not present in the runtime version.
///
/// # Safety
///
/// If you are sure the runtime version is compatible with the
/// build-time version for your usage, you can bypass the version check by
/// calling this function before
/// your first connection attempt.
/// calling this function before your first connection attempt.
pub unsafe fn bypass_sqlite_version_check() {
#[cfg(not(feature = "bundled"))]
inner_connection::BYPASS_VERSION_CHECK.store(true, Ordering::Relaxed);
@@ -867,7 +885,7 @@ impl InterruptHandle {
}
}
#[cfg(feature = "bundled")] // 3.7.10
#[cfg(feature = "modern_sqlite")] // 3.7.10
unsafe fn db_filename(db: *mut ffi::sqlite3) -> Option<PathBuf> {
let db_name = DatabaseName::Main.to_cstring().unwrap();
let db_filename = ffi::sqlite3_db_filename(db, db_name.as_ptr());
@@ -877,20 +895,19 @@ unsafe fn db_filename(db: *mut ffi::sqlite3) -> Option<PathBuf> {
CStr::from_ptr(db_filename).to_str().ok().map(PathBuf::from)
}
}
#[cfg(not(feature = "bundled"))]
#[cfg(not(feature = "modern_sqlite"))]
unsafe fn db_filename(_: *mut ffi::sqlite3) -> Option<PathBuf> {
None
}
#[cfg(test)]
mod test {
use self::tempdir::TempDir;
pub use super::*;
use super::*;
use crate::ffi;
use fallible_iterator::FallibleIterator;
pub use std::error::Error as StdError;
pub use std::fmt;
use tempdir;
use std::error::Error as StdError;
use std::fmt;
use tempfile;
// this function is never called, but is still type checked; in
// particular, calls with specific instantiations will require
@@ -913,7 +930,7 @@ mod test {
#[test]
fn test_concurrent_transactions_busy_commit() {
use std::time::Duration;
let tmp = TempDir::new("locked").unwrap();
let tmp = tempfile::tempdir().unwrap();
let path = tmp.path().join("transactions.db3");
Connection::open(&path)
@@ -925,8 +942,9 @@ mod test {
)
.expect("create temp db");
let mut db1 = Connection::open(&path).unwrap();
let mut db2 = Connection::open(&path).unwrap();
let mut db1 =
Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_WRITE).unwrap();
let mut db2 = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_ONLY).unwrap();
db1.busy_timeout(Duration::from_millis(0)).unwrap();
db2.busy_timeout(Duration::from_millis(0)).unwrap();
@@ -958,7 +976,7 @@ mod test {
#[test]
fn test_persistence() {
let temp_dir = TempDir::new("test_open_file").unwrap();
let temp_dir = tempfile::tempdir().unwrap();
let path = temp_dir.path().join("test.db3");
{
@@ -985,6 +1003,26 @@ mod test {
assert!(db.close().is_ok());
}
#[test]
fn test_open_failure() {
let filename = "no_such_file.db";
let result = Connection::open_with_flags(filename, OpenFlags::SQLITE_OPEN_READ_ONLY);
assert!(!result.is_ok());
let err = result.err().unwrap();
if let Error::SqliteFailure(e, Some(msg)) = err {
assert_eq!(ErrorCode::CannotOpen, e.code);
assert_eq!(ffi::SQLITE_CANTOPEN, e.extended_code);
assert!(
msg.contains(filename),
"error message '{}' does not contain '{}'",
msg,
filename
);
} else {
panic!("SqliteFailure expected");
}
}
#[test]
fn test_close_retry() {
let db = checked_memory_handle();
@@ -994,23 +1032,25 @@ mod test {
// statement first.
let raw_stmt = {
use super::str_to_cstring;
use std::mem;
use std::mem::MaybeUninit;
use std::os::raw::c_int;
use std::ptr;
let raw_db = db.db.borrow_mut().db;
let sql = "SELECT 1";
let mut raw_stmt: *mut ffi::sqlite3_stmt = unsafe { mem::uninitialized() };
let mut raw_stmt = MaybeUninit::uninit();
let cstring = str_to_cstring(sql).unwrap();
let rc = unsafe {
ffi::sqlite3_prepare_v2(
raw_db,
str_to_cstring(sql).unwrap().as_ptr(),
cstring.as_ptr(),
(sql.len() + 1) as c_int,
&mut raw_stmt,
raw_stmt.as_mut_ptr(),
ptr::null_mut(),
)
};
assert_eq!(rc, ffi::SQLITE_OK);
let raw_stmt: *mut ffi::sqlite3_stmt = unsafe { raw_stmt.assume_init() };
raw_stmt
};
@@ -1087,6 +1127,22 @@ mod test {
}
}
#[test]
#[cfg(feature = "extra_check")]
fn test_execute_multiple() {
let db = checked_memory_handle();
let err = db
.execute(
"CREATE TABLE foo(x INTEGER); CREATE TABLE foo(x INTEGER)",
NO_PARAMS,
)
.unwrap_err();
match err {
Error::MultipleStatement => (),
_ => panic!("Unexpected error: {}", err),
}
}
#[test]
fn test_prepare_column_names() {
let db = checked_memory_handle();
@@ -1284,7 +1340,7 @@ mod test {
}
#[test]
#[cfg(feature = "bundled")]
#[cfg(feature = "modern_sqlite")]
fn test_is_busy() {
let db = checked_memory_handle();
assert!(!db.is_busy());
@@ -1313,11 +1369,11 @@ mod test {
fn test_notnull_constraint_error() {
// extended error codes for constraints were added in SQLite 3.7.16; if we're
// running on our bundled version, we know the extended error code exists.
#[cfg(feature = "bundled")]
#[cfg(feature = "modern_sqlite")]
fn check_extended_code(extended_code: c_int) {
assert_eq!(extended_code, ffi::SQLITE_CONSTRAINT_NOTNULL);
}
#[cfg(not(feature = "bundled"))]
#[cfg(not(feature = "modern_sqlite"))]
fn check_extended_code(_extended_code: c_int) {}
let db = checked_memory_handle();
@@ -1352,10 +1408,15 @@ mod test {
let interrupt_handle = db.get_interrupt_handle();
db.create_scalar_function("interrupt", 0, false, move |_| {
interrupt_handle.interrupt();
Ok(0)
})
db.create_scalar_function(
"interrupt",
0,
crate::functions::FunctionFlags::default(),
move |_| {
interrupt_handle.interrupt();
Ok(0)
},
)
.unwrap();
let mut stmt = db
@@ -1367,7 +1428,6 @@ mod test {
match result.unwrap_err() {
Error::SqliteFailure(err, _) => {
assert_eq!(err.code, ErrorCode::OperationInterrupted);
return;
}
err => {
panic!("Unexpected error {}", err);
@@ -1437,8 +1497,8 @@ mod test {
impl fmt::Display for CustomError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> ::std::result::Result<(), fmt::Error> {
match *self {
CustomError::SomeError => write!(f, "{}", self.description()),
CustomError::Sqlite(ref se) => write!(f, "{}: {}", self.description(), se),
CustomError::SomeError => write!(f, "my custom error"),
CustomError::Sqlite(ref se) => write!(f, "my custom error: {}", se),
}
}
}
@@ -1504,7 +1564,7 @@ mod test {
.collect();
match bad_type.unwrap_err() {
Error::InvalidColumnType(_, _) => (),
Error::InvalidColumnType(..) => (),
err => panic!("Unexpected error {}", err),
}
@@ -1559,7 +1619,7 @@ mod test {
.collect();
match bad_type.unwrap_err() {
CustomError::Sqlite(Error::InvalidColumnType(_, _)) => (),
CustomError::Sqlite(Error::InvalidColumnType(..)) => (),
err => panic!("Unexpected error {}", err),
}
@@ -1616,7 +1676,7 @@ mod test {
});
match bad_type.unwrap_err() {
CustomError::Sqlite(Error::InvalidColumnType(_, _)) => (),
CustomError::Sqlite(Error::InvalidColumnType(..)) => (),
err => panic!("Unexpected error {}", err),
}
@@ -1665,5 +1725,26 @@ mod test {
})
.unwrap();
}
#[test]
fn test_params() {
let db = checked_memory_handle();
db.query_row(
"SELECT
?, ?, ?, ?, ?, ?, ?, ?, ?, ?,
?, ?, ?, ?, ?, ?, ?, ?, ?, ?,
?, ?, ?, ?, ?, ?, ?, ?, ?, ?,
?, ?, ?, ?;",
params![
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1,
],
|r| {
assert_eq!(1, r.get_unwrap::<_, i32>(0));
Ok(())
},
)
.unwrap();
}
}
}

View File

@@ -86,6 +86,7 @@ impl Sql {
self.push_real(r);
}
ValueRef::Text(s) => {
let s = std::str::from_utf8(s)?;
self.push_string_literal(s);
}
_ => {
@@ -325,7 +326,7 @@ mod test {
}
#[test]
#[cfg(feature = "bundled")]
#[cfg(feature = "modern_sqlite")]
fn pragma_func_query_value() {
use crate::NO_PARAMS;
@@ -378,7 +379,7 @@ mod test {
}
#[test]
#[cfg(feature = "bundled")]
#[cfg(feature = "modern_sqlite")]
fn pragma_func() {
let db = Connection::open_in_memory().unwrap();
let mut table_info = db.prepare("SELECT * FROM pragma_table_info(?)").unwrap();

View File

@@ -7,11 +7,15 @@ use std::ptr;
// Private newtype for raw sqlite3_stmts that finalize themselves when dropped.
#[derive(Debug)]
pub struct RawStatement(*mut ffi::sqlite3_stmt);
pub struct RawStatement(*mut ffi::sqlite3_stmt, bool);
impl RawStatement {
pub fn new(stmt: *mut ffi::sqlite3_stmt) -> RawStatement {
RawStatement(stmt)
pub fn new(stmt: *mut ffi::sqlite3_stmt, tail: bool) -> RawStatement {
RawStatement(stmt, tail)
}
pub fn is_null(&self) -> bool {
self.0.is_null()
}
pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt {
@@ -37,8 +41,21 @@ impl RawStatement {
}
}
pub fn column_name(&self, idx: usize) -> &CStr {
unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.0, idx as c_int)) }
pub fn column_name(&self, idx: usize) -> Option<&CStr> {
let idx = idx as c_int;
if idx < 0 || idx >= self.column_count() as c_int {
return None;
}
unsafe {
let ptr = ffi::sqlite3_column_name(self.0, idx);
// If ptr is null here, it's an OOM, so there's probably nothing
// meaningful we can do. Just assert instead of returning None.
assert!(
!ptr.is_null(),
"Null pointer from sqlite3_column_name: Out of memory?"
);
Some(CStr::from_ptr(ptr))
}
}
pub fn step(&self) -> c_int {
@@ -82,8 +99,12 @@ impl RawStatement {
unsafe { ffi::sqlite3_clear_bindings(self.0) }
}
pub fn sql(&self) -> &CStr {
unsafe { CStr::from_ptr(ffi::sqlite3_sql(self.0)) }
pub fn sql(&self) -> Option<&CStr> {
if self.0.is_null() {
None
} else {
Some(unsafe { CStr::from_ptr(ffi::sqlite3_sql(self.0)) })
}
}
pub fn finalize(mut self) -> c_int {
@@ -96,20 +117,19 @@ impl RawStatement {
r
}
#[cfg(feature = "bundled")]
#[cfg(feature = "modern_sqlite")] // 3.7.4
pub fn readonly(&self) -> bool {
unsafe { ffi::sqlite3_stmt_readonly(self.0) != 0 }
}
#[cfg(feature = "bundled")]
pub fn expanded_sql(&self) -> Option<&CStr> {
unsafe {
let ptr = ffi::sqlite3_expanded_sql(self.0);
if ptr.is_null() {
None
} else {
Some(CStr::from_ptr(ptr))
}
/// `CStr` must be freed
#[cfg(feature = "modern_sqlite")] // 3.14.0
pub unsafe fn expanded_sql(&self) -> Option<&CStr> {
let ptr = ffi::sqlite3_expanded_sql(self.0);
if ptr.is_null() {
None
} else {
Some(CStr::from_ptr(ptr))
}
}
@@ -117,6 +137,10 @@ impl RawStatement {
assert!(!self.0.is_null());
unsafe { ffi::sqlite3_stmt_status(self.0, status as i32, reset as i32) }
}
pub fn has_tail(&self) -> bool {
self.1
}
}
impl Drop for RawStatement {

View File

@@ -223,15 +223,27 @@ impl<'stmt> Row<'stmt> {
let idx = idx.idx(self.stmt)?;
let value = self.stmt.value_ref(idx);
FromSql::column_result(value).map_err(|err| match err {
FromSqlError::InvalidType => Error::InvalidColumnType(idx, value.data_type()),
FromSqlError::InvalidType => Error::InvalidColumnType(
idx,
self.stmt.column_name_unwrap(idx).into(),
value.data_type(),
),
FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
FromSqlError::Other(err) => {
Error::FromSqlConversionFailure(idx as usize, value.data_type(), err)
}
#[cfg(feature = "i128_blob")]
FromSqlError::InvalidI128Size(_) => Error::InvalidColumnType(idx, value.data_type()),
FromSqlError::InvalidI128Size(_) => Error::InvalidColumnType(
idx,
self.stmt.column_name_unwrap(idx).into(),
value.data_type(),
),
#[cfg(feature = "uuid")]
FromSqlError::InvalidUuidSize(_) => Error::InvalidColumnType(idx, value.data_type()),
FromSqlError::InvalidUuidSize(_) => Error::InvalidColumnType(
idx,
self.stmt.column_name_unwrap(idx).into(),
value.data_type(),
),
})
}

View File

@@ -4,7 +4,7 @@
use std::ffi::CStr;
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::mem;
use std::mem::MaybeUninit;
use std::os::raw::{c_char, c_int, c_uchar, c_void};
use std::panic::{catch_unwind, RefUnwindSafe};
use std::ptr;
@@ -43,8 +43,9 @@ impl Session<'_> {
let db = db.db.borrow_mut().db;
let mut s: *mut ffi::sqlite3_session = unsafe { mem::uninitialized() };
check!(unsafe { ffi::sqlite3session_create(db, name.as_ptr(), &mut s) });
let mut s = MaybeUninit::uninit();
check!(unsafe { ffi::sqlite3session_create(db, name.as_ptr(), s.as_mut_ptr()) });
let s: *mut ffi::sqlite3_session = unsafe { s.assume_init() };
Ok(Session {
phantom: PhantomData,
@@ -65,7 +66,6 @@ impl Session<'_> {
where
F: Fn(&str) -> bool + RefUnwindSafe,
{
use std::ffi::CStr;
use std::str;
let boxed_filter: *mut F = p_arg as *mut F;
@@ -113,8 +113,9 @@ impl Session<'_> {
/// Generate a Changeset
pub fn changeset(&mut self) -> Result<Changeset> {
let mut n = 0;
let mut cs: *mut c_void = unsafe { mem::uninitialized() };
check!(unsafe { ffi::sqlite3session_changeset(self.s, &mut n, &mut cs) });
let mut cs = MaybeUninit::uninit();
check!(unsafe { ffi::sqlite3session_changeset(self.s, &mut n, cs.as_mut_ptr()) });
let cs: *mut c_void = unsafe { cs.assume_init() };
Ok(Changeset { cs, n })
}
@@ -134,8 +135,9 @@ impl Session<'_> {
/// Generate a Patchset
pub fn patchset(&mut self) -> Result<Changeset> {
let mut n = 0;
let mut ps: *mut c_void = unsafe { mem::uninitialized() };
check!(unsafe { ffi::sqlite3session_patchset(self.s, &mut n, &mut ps) });
let mut ps = MaybeUninit::uninit();
check!(unsafe { ffi::sqlite3session_patchset(self.s, &mut n, ps.as_mut_ptr()) });
let ps: *mut c_void = unsafe { ps.assume_init() };
// TODO Validate: same struct
Ok(Changeset { cs: ps, n })
}
@@ -158,9 +160,10 @@ impl Session<'_> {
let from = from.to_cstring()?;
let table = str_to_cstring(table)?.as_ptr();
unsafe {
let mut errmsg: *mut c_char = mem::uninitialized();
let r = ffi::sqlite3session_diff(self.s, from.as_ptr(), table, &mut errmsg);
let mut errmsg = MaybeUninit::uninit();
let r = ffi::sqlite3session_diff(self.s, from.as_ptr(), table, errmsg.as_mut_ptr());
if r != ffi::SQLITE_OK {
let errmsg: *mut c_char = errmsg.assume_init();
let message = errmsg_to_string(&*errmsg);
ffi::sqlite3_free(errmsg as *mut ::std::os::raw::c_void);
return Err(error_from_sqlite_code(r, Some(message)));
@@ -255,15 +258,17 @@ impl Changeset {
/// Invert a changeset
pub fn invert(&self) -> Result<Changeset> {
let mut n = 0;
let mut cs: *mut c_void = unsafe { mem::uninitialized() };
check!(unsafe { ffi::sqlite3changeset_invert(self.n, self.cs, &mut n, &mut cs) });
let mut cs = MaybeUninit::uninit();
check!(unsafe { ffi::sqlite3changeset_invert(self.n, self.cs, &mut n, cs.as_mut_ptr()) });
let cs: *mut c_void = unsafe { cs.assume_init() };
Ok(Changeset { cs, n })
}
/// Create an iterator to traverse a changeset
pub fn iter(&self) -> Result<ChangesetIter<'_>> {
let mut it: *mut ffi::sqlite3_changeset_iter = unsafe { mem::uninitialized() };
check!(unsafe { ffi::sqlite3changeset_start(&mut it, self.n, self.cs) });
let mut it = MaybeUninit::uninit();
check!(unsafe { ffi::sqlite3changeset_start(it.as_mut_ptr(), self.n, self.cs) });
let it: *mut ffi::sqlite3_changeset_iter = unsafe { it.assume_init() };
Ok(ChangesetIter {
phantom: PhantomData,
it,
@@ -274,8 +279,11 @@ impl Changeset {
/// Concatenate two changeset objects
pub fn concat(a: &Changeset, b: &Changeset) -> Result<Changeset> {
let mut n = 0;
let mut cs: *mut c_void = unsafe { mem::uninitialized() };
check!(unsafe { ffi::sqlite3changeset_concat(a.n, a.cs, b.n, b.cs, &mut n, &mut cs) });
let mut cs = MaybeUninit::uninit();
check!(unsafe {
ffi::sqlite3changeset_concat(a.n, a.cs, b.n, b.cs, &mut n, cs.as_mut_ptr())
});
let cs: *mut c_void = unsafe { cs.assume_init() };
Ok(Changeset { cs, n })
}
}
@@ -297,16 +305,16 @@ pub struct ChangesetIter<'changeset> {
impl ChangesetIter<'_> {
/// Create an iterator on `input`
pub fn start_strm<'input>(input: &'input mut dyn Read) -> Result<ChangesetIter<'input>> {
let input_ref = &input;
let mut it: *mut ffi::sqlite3_changeset_iter = unsafe { mem::uninitialized() };
pub fn start_strm<'input>(input: &&'input mut dyn Read) -> Result<ChangesetIter<'input>> {
let mut it = MaybeUninit::uninit();
check!(unsafe {
ffi::sqlite3changeset_start_strm(
&mut it,
it.as_mut_ptr(),
Some(x_input),
input_ref as *const &mut dyn Read as *mut c_void,
input as *const &mut dyn Read as *mut c_void,
)
});
let it: *mut ffi::sqlite3_changeset_iter = unsafe { it.assume_init() };
Ok(ChangesetIter {
phantom: PhantomData,
it,
@@ -386,12 +394,13 @@ impl ChangesetItem {
/// `SQLITE_CHANGESET_CONFLICT` conflict handler callback.
pub fn conflict(&self, col: usize) -> Result<ValueRef<'_>> {
unsafe {
let mut p_value: *mut ffi::sqlite3_value = mem::uninitialized();
let mut p_value = MaybeUninit::uninit();
check!(ffi::sqlite3changeset_conflict(
self.it,
col as i32,
&mut p_value
p_value.as_mut_ptr()
));
let p_value: *mut ffi::sqlite3_value = p_value.assume_init();
Ok(ValueRef::from_value(p_value))
}
}
@@ -414,8 +423,13 @@ impl ChangesetItem {
/// `SQLITE_INSERT`.
pub fn new_value(&self, col: usize) -> Result<ValueRef<'_>> {
unsafe {
let mut p_value: *mut ffi::sqlite3_value = mem::uninitialized();
check!(ffi::sqlite3changeset_new(self.it, col as i32, &mut p_value));
let mut p_value = MaybeUninit::uninit();
check!(ffi::sqlite3changeset_new(
self.it,
col as i32,
p_value.as_mut_ptr()
));
let p_value: *mut ffi::sqlite3_value = p_value.assume_init();
Ok(ValueRef::from_value(p_value))
}
}
@@ -426,8 +440,13 @@ impl ChangesetItem {
/// `SQLITE_UPDATE`.
pub fn old_value(&self, col: usize) -> Result<ValueRef<'_>> {
unsafe {
let mut p_value: *mut ffi::sqlite3_value = mem::uninitialized();
check!(ffi::sqlite3changeset_old(self.it, col as i32, &mut p_value));
let mut p_value = MaybeUninit::uninit();
check!(ffi::sqlite3changeset_old(
self.it,
col as i32,
p_value.as_mut_ptr()
));
let p_value: *mut ffi::sqlite3_value = p_value.assume_init();
Ok(ValueRef::from_value(p_value))
}
}
@@ -438,14 +457,15 @@ impl ChangesetItem {
let mut code = 0;
let mut indirect = 0;
let tab = unsafe {
let mut pz_tab: *const c_char = mem::uninitialized();
let mut pz_tab = MaybeUninit::uninit();
check!(ffi::sqlite3changeset_op(
self.it,
&mut pz_tab,
pz_tab.as_mut_ptr(),
&mut number_of_columns,
&mut code,
&mut indirect
));
let pz_tab: *const c_char = pz_tab.assume_init();
CStr::from_ptr(pz_tab)
};
let table_name = tab.to_str()?;
@@ -461,12 +481,13 @@ impl ChangesetItem {
pub fn pk(&self) -> Result<&[u8]> {
let mut number_of_columns = 0;
unsafe {
let mut pks: *mut c_uchar = mem::uninitialized();
let mut pks = MaybeUninit::uninit();
check!(ffi::sqlite3changeset_pk(
self.it,
&mut pks,
pks.as_mut_ptr(),
&mut number_of_columns
));
let pks: *mut c_uchar = pks.assume_init();
Ok(from_raw_parts(pks, number_of_columns as usize))
}
}
@@ -480,8 +501,9 @@ pub struct Changegroup {
impl Changegroup {
pub fn new() -> Result<Self> {
let mut cg: *mut ffi::sqlite3_changegroup = unsafe { mem::uninitialized() };
check!(unsafe { ffi::sqlite3changegroup_new(&mut cg) });
let mut cg = MaybeUninit::uninit();
check!(unsafe { ffi::sqlite3changegroup_new(cg.as_mut_ptr()) });
let cg: *mut ffi::sqlite3_changegroup = unsafe { cg.assume_init() };
Ok(Changegroup { cg })
}
@@ -507,8 +529,9 @@ impl Changegroup {
/// Obtain a composite Changeset
pub fn output(&mut self) -> Result<Changeset> {
let mut n = 0;
let mut output: *mut c_void = unsafe { mem::uninitialized() };
check!(unsafe { ffi::sqlite3changegroup_output(self.cg, &mut n, &mut output) });
let mut output = MaybeUninit::uninit();
check!(unsafe { ffi::sqlite3changegroup_output(self.cg, &mut n, output.as_mut_ptr()) });
let output: *mut c_void = unsafe { output.assume_init() };
Ok(Changeset { cs: output, n })
}
@@ -648,7 +671,6 @@ where
F: Fn(&str) -> bool + Send + RefUnwindSafe + 'static,
C: Fn(ConflictType, ChangesetItem) -> ConflictAction + Send + RefUnwindSafe + 'static,
{
use std::ffi::CStr;
use std::str;
let tuple: *mut (Option<F>, C) = p_ctx as *mut (Option<F>, C);
@@ -691,7 +713,7 @@ unsafe extern "C" fn x_input(p_in: *mut c_void, data: *mut c_void, len: *mut c_i
if p_in.is_null() {
return ffi::SQLITE_MISUSE;
}
let bytes: &mut [u8] = from_raw_parts_mut(data as *mut u8, len as usize);
let bytes: &mut [u8] = from_raw_parts_mut(data as *mut u8, *len as usize);
let input = p_in as *mut &mut dyn Read;
match (*input).read(bytes) {
Ok(n) => {
@@ -720,6 +742,7 @@ unsafe extern "C" fn x_output(p_out: *mut c_void, data: *const c_void, len: c_in
#[cfg(test)]
mod test {
use fallible_streaming_iterator::FallibleStreamingIterator;
use std::io::Read;
use std::sync::atomic::{AtomicBool, Ordering};
use super::{Changeset, ChangesetIter, ConflictAction, ConflictType, Session};
@@ -785,8 +808,8 @@ mod test {
assert!(!output.is_empty());
assert_eq!(14, output.len());
let mut input = output.as_slice();
let mut iter = ChangesetIter::start_strm(&mut input).unwrap();
let input: &mut dyn Read = &mut output.as_slice();
let mut iter = ChangesetIter::start_strm(&input).unwrap();
let item = iter.next().unwrap();
assert!(item.is_some());
}
@@ -799,7 +822,7 @@ mod test {
db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);")
.unwrap();
lazy_static! {
lazy_static::lazy_static! {
static ref CALLED: AtomicBool = AtomicBool::new(false);
}
db.apply(
@@ -844,8 +867,9 @@ mod test {
db.execute_batch("CREATE TABLE foo(t TEXT PRIMARY KEY NOT NULL);")
.unwrap();
let mut input = output.as_slice();
db.apply_strm(
&mut output.as_slice(),
&mut input,
None::<fn(&str) -> bool>,
|_conflict_type, _item| ConflictAction::SQLITE_CHANGESET_OMIT,
)

View File

@@ -45,7 +45,7 @@ impl Statement<'_> {
///
/// Will return `Err` if binding parameters fails, the executed statement
/// returns rows (in which case `query` should be used instead), or the
/// underling SQLite call fails.
/// underlying SQLite call fails.
pub fn execute<P>(&mut self, params: P) -> Result<usize>
where
P: IntoIterator,
@@ -89,7 +89,7 @@ impl Statement<'_> {
///
/// Will return `Err` if binding parameters fails, the executed statement
/// returns rows (in which case `query` should be used instead), or the
/// underling SQLite call fails.
/// underlying SQLite call fails.
pub fn execute_named(&mut self, params: &[(&str, &dyn ToSql)]) -> Result<usize> {
self.bind_parameters_named(params)?;
self.execute_with_bound_parameters()
@@ -378,6 +378,29 @@ impl Statement<'_> {
rows.get_expected_row().and_then(|r| f(&r))
}
/// Convenience method to execute a query with named parameter(s) that is
/// expected to return a single row.
///
/// If the query returns more than one row, all rows except the first are
/// ignored.
///
/// Returns `Err(QueryReturnedNoRows)` if no results are returned. If the
/// query truly is optional, you can call `.optional()` on the result of
/// this to get a `Result<Option<T>>`.
///
/// # Failure
///
/// Will return `Err` if `sql` cannot be converted to a C-compatible string
/// or if the underlying SQLite call fails.
pub fn query_row_named<T, F>(&mut self, params: &[(&str, &dyn ToSql)], f: F) -> Result<T>
where
F: FnOnce(&Row<'_>) -> Result<T>,
{
let mut rows = self.query_named(params)?;
rows.get_expected_row().and_then(|r| f(&r))
}
/// Consumes the statement.
///
/// Functionally equivalent to the `Drop` implementation, but allows
@@ -488,6 +511,7 @@ impl Statement<'_> {
}
fn execute_with_bound_parameters(&mut self) -> Result<usize> {
self.check_update()?;
let r = self.stmt.step();
self.stmt.reset();
match r {
@@ -504,18 +528,18 @@ impl Statement<'_> {
}
fn finalize_(&mut self) -> Result<()> {
let mut stmt = RawStatement::new(ptr::null_mut());
let mut stmt = RawStatement::new(ptr::null_mut(), false);
mem::swap(&mut stmt, &mut self.stmt);
self.conn.decode_result(stmt.finalize())
}
#[cfg(not(feature = "bundled"))]
#[cfg(not(feature = "modern_sqlite"))]
#[inline]
fn check_readonly(&self) -> Result<()> {
Ok(())
}
#[cfg(feature = "bundled")]
#[cfg(feature = "modern_sqlite")]
#[inline]
fn check_readonly(&self) -> Result<()> {
/*if !self.stmt.readonly() { does not work for PRAGMA
@@ -524,14 +548,43 @@ impl Statement<'_> {
Ok(())
}
#[cfg(all(feature = "modern_sqlite", feature = "extra_check"))]
#[inline]
fn check_update(&self) -> Result<()> {
if self.column_count() > 0 || self.stmt.readonly() {
return Err(Error::ExecuteReturnedResults);
}
Ok(())
}
#[cfg(all(not(feature = "modern_sqlite"), feature = "extra_check"))]
#[inline]
fn check_update(&self) -> Result<()> {
if self.column_count() > 0 {
return Err(Error::ExecuteReturnedResults);
}
Ok(())
}
#[cfg(not(feature = "extra_check"))]
#[inline]
fn check_update(&self) -> Result<()> {
Ok(())
}
/// Returns a string containing the SQL text of prepared statement with
/// bound parameters expanded.
#[cfg(feature = "bundled")]
pub fn expanded_sql(&self) -> Option<&str> {
#[cfg(feature = "modern_sqlite")]
pub fn expanded_sql(&self) -> Option<String> {
unsafe {
self.stmt
.expanded_sql()
.map(|s| str::from_utf8_unchecked(s.to_bytes()))
match self.stmt.expanded_sql() {
Some(s) => {
let sql = str::from_utf8_unchecked(s.to_bytes()).to_owned();
ffi::sqlite3_free(s.as_ptr() as *mut _);
Some(sql)
}
_ => None,
}
}
}
@@ -545,11 +598,26 @@ impl Statement<'_> {
pub fn reset_status(&self, status: StatementStatus) -> i32 {
self.stmt.get_status(status, true)
}
#[cfg(feature = "extra_check")]
pub(crate) fn check_no_tail(&self) -> Result<()> {
if self.stmt.has_tail() {
Err(Error::MultipleStatement)
} else {
Ok(())
}
}
#[cfg(not(feature = "extra_check"))]
#[inline]
pub(crate) fn check_no_tail(&self) -> Result<()> {
Ok(())
}
}
impl Into<RawStatement> for Statement<'_> {
fn into(mut self) -> RawStatement {
let mut stmt = RawStatement::new(ptr::null_mut());
let mut stmt = RawStatement::new(ptr::null_mut(), false);
mem::swap(&mut stmt, &mut self.stmt);
stmt
}
@@ -557,7 +625,11 @@ impl Into<RawStatement> for Statement<'_> {
impl fmt::Debug for Statement<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let sql = str::from_utf8(self.stmt.sql().to_bytes());
let sql = if self.stmt.is_null() {
Ok("")
} else {
str::from_utf8(self.stmt.sql().unwrap().to_bytes())
};
f.debug_struct("Statement")
.field("conn", self.conn)
.field("stmt", &self.stmt)
@@ -599,10 +671,7 @@ impl Statement<'_> {
CStr::from_ptr(text as *const c_char)
};
// sqlite3_column_text returns UTF8 data, so our unwrap here should be fine.
let s = s
.to_str()
.expect("sqlite3_column_text returned invalid UTF-8");
let s = s.to_bytes();
ValueRef::Text(s)
}
ffi::SQLITE_BLOB => {
@@ -716,14 +785,13 @@ mod test {
.unwrap();
stmt.execute_named(&[(":name", &"one")]).unwrap();
let mut stmt = db
.prepare("SELECT COUNT(*) FROM test WHERE name = :name")
.unwrap();
assert_eq!(
1i32,
db.query_row_named::<i32, _>(
"SELECT COUNT(*) FROM test WHERE name = :name",
&[(":name", &"one")],
|r| r.get(0)
)
.unwrap()
stmt.query_row_named::<i32, _>(&[(":name", &"one")], |r| r.get(0))
.unwrap()
);
}
@@ -951,12 +1019,12 @@ mod test {
}
#[test]
#[cfg(feature = "bundled")]
#[cfg(feature = "modern_sqlite")]
fn test_expanded_sql() {
let db = Connection::open_in_memory().unwrap();
let stmt = db.prepare("SELECT ?").unwrap();
stmt.bind_parameter(&1, 1).unwrap();
assert_eq!(Some("SELECT 1"), stmt.expanded_sql());
assert_eq!(Some("SELECT 1".to_owned()), stmt.expanded_sql());
}
#[test]
@@ -983,7 +1051,7 @@ mod test {
use std::collections::BTreeSet;
let data: BTreeSet<String> = ["one", "two", "three"]
.iter()
.map(|s| s.to_string())
.map(|s| (*s).to_string())
.collect();
db.query_row("SELECT ?1, ?2, ?3", &data, |row| row.get::<_, String>(0))
.unwrap();
@@ -994,4 +1062,35 @@ mod test {
db.query_row("SELECT ?1, ?2, ?3", data.iter(), |row| row.get::<_, u8>(0))
.unwrap();
}
#[test]
fn test_empty_stmt() {
let conn = Connection::open_in_memory().unwrap();
let mut stmt = conn.prepare("").unwrap();
assert_eq!(0, stmt.column_count());
assert!(stmt.parameter_index("test").is_ok());
assert!(stmt.step().is_err());
stmt.reset();
assert!(stmt.execute(NO_PARAMS).is_err());
}
#[test]
fn test_comment_stmt() {
let conn = Connection::open_in_memory().unwrap();
conn.prepare("/*SELECT 1;*/").unwrap();
}
#[test]
fn test_comment_and_sql_stmt() {
let conn = Connection::open_in_memory().unwrap();
let stmt = conn.prepare("/*...*/ SELECT 1;").unwrap();
assert_eq!(1, stmt.column_count());
}
#[test]
fn test_semi_colon_stmt() {
let conn = Connection::open_in_memory().unwrap();
let stmt = conn.prepare(";").unwrap();
assert_eq!(0, stmt.column_count());
}
}

View File

@@ -12,6 +12,9 @@ use crate::error::error_from_sqlite_code;
use crate::{Connection, Result};
/// Set up the process-wide SQLite error logging callback.
///
/// # Safety
///
/// This function is marked unsafe for two reasons:
///
/// * The function is not threadsafe. No other SQLite calls may be made while
@@ -122,6 +125,7 @@ impl Connection {
#[cfg(test)]
mod test {
use lazy_static::lazy_static;
use std::sync::Mutex;
use std::time::Duration;

View File

@@ -1,9 +1,8 @@
//! Convert most of the [Time Strings](http://sqlite.org/lang_datefunc.html) to chrono types.
use chrono;
use std::borrow::Cow;
use self::chrono::{DateTime, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
use chrono::{DateTime, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef};
use crate::Result;
@@ -54,7 +53,7 @@ impl FromSql for NaiveTime {
}
/// ISO 8601 combined date and time without timezone =>
/// "YYYY-MM-DD HH:MM:SS.SSS"
/// "YYYY-MM-DDTHH:MM:SS.SSS"
impl ToSql for NaiveDateTime {
fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
let date_str = self.format("%Y-%m-%dT%H:%M:%S%.f").to_string();
@@ -129,10 +128,8 @@ impl FromSql for DateTime<Local> {
#[cfg(test)]
mod test {
use super::chrono::{
DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc,
};
use crate::{Connection, Result, NO_PARAMS};
use chrono::{DateTime, Duration, Local, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc};
fn checked_memory_handle() -> Connection {
let db = Connection::open_in_memory().unwrap();

View File

@@ -36,7 +36,7 @@ impl PartialEq for FromSqlError {
(FromSqlError::InvalidI128Size(s1), FromSqlError::InvalidI128Size(s2)) => s1 == s2,
#[cfg(feature = "uuid")]
(FromSqlError::InvalidUuidSize(s1), FromSqlError::InvalidUuidSize(s2)) => s1 == s2,
(_, _) => false,
(..) => false,
}
}
}
@@ -60,6 +60,7 @@ impl fmt::Display for FromSqlError {
}
impl Error for FromSqlError {
#[allow(deprecated)]
fn description(&self) -> &str {
match *self {
FromSqlError::InvalidType => "invalid type",
@@ -166,6 +167,24 @@ impl FromSql for String {
}
}
impl FromSql for Box<str> {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
value.as_str().map(Into::into)
}
}
impl FromSql for std::rc::Rc<str> {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
value.as_str().map(Into::into)
}
}
impl FromSql for std::sync::Arc<str> {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
value.as_str().map(Into::into)
}
}
impl FromSql for Vec<u8> {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
value.as_blob().map(|b| b.to_vec())

View File

@@ -42,10 +42,6 @@
//! Ok(as_f64.into())
//! }
//! }
//!
//! # // Prevent this doc test from being wrapped in a `fn main()` so that it
//! # // will compile.
//! # fn main() {}
//! ```
//!
//! `ToSql` and `FromSql` are also implemented for `Option<T>` where `T`
@@ -78,7 +74,7 @@ mod value_ref;
/// ```rust,no_run
/// # use rusqlite::{Connection, Result};
/// # use rusqlite::types::{Null};
/// fn main() {}
///
/// fn insert_null(conn: &Connection) -> Result<usize> {
/// conn.execute("INSERT INTO people (name) VALUES (?)", &[Null])
/// }
@@ -229,7 +225,7 @@ mod test {
fn test_mismatched_types() {
fn is_invalid_column_type(err: Error) -> bool {
match err {
Error::InvalidColumnType(_, _) => true,
Error::InvalidColumnType(..) => true,
_ => false,
}
}

View File

@@ -1,7 +1,6 @@
//! `ToSql` and `FromSql` implementation for JSON `Value`.
use serde_json;
use self::serde_json::Value;
use serde_json::Value;
use crate::types::{FromSql, FromSqlError, FromSqlResult, ToSql, ToSqlOutput, ValueRef};
use crate::Result;
@@ -17,7 +16,7 @@ impl ToSql for Value {
impl FromSql for Value {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
match value {
ValueRef::Text(s) => serde_json::from_str(s),
ValueRef::Text(s) => serde_json::from_slice(s),
ValueRef::Blob(b) => serde_json::from_slice(b),
_ => return Err(FromSqlError::InvalidType),
}
@@ -27,9 +26,9 @@ impl FromSql for Value {
#[cfg(test)]
mod test {
use super::serde_json;
use crate::types::ToSql;
use crate::{Connection, NO_PARAMS};
use serde_json;
fn checked_memory_handle() -> Connection {
let db = Connection::open_in_memory().unwrap();

View File

@@ -36,8 +36,8 @@ impl FromSql for time::Timespec {
#[cfg(test)]
mod test {
use super::time;
use crate::{Connection, Result, NO_PARAMS};
use time;
fn checked_memory_handle() -> Connection {
let db = Connection::open_in_memory().unwrap();

View File

@@ -14,10 +14,12 @@ impl ToSql for Url {
impl FromSql for Url {
fn column_result(value: ValueRef<'_>) -> FromSqlResult<Self> {
match value {
ValueRef::Text(s) => Url::parse(s),
_ => return Err(FromSqlError::InvalidType),
ValueRef::Text(s) => {
let s = std::str::from_utf8(s).map_err(|e| FromSqlError::Other(Box::new(e)))?;
Url::parse(s).map_err(|e| FromSqlError::Other(Box::new(e)))
}
_ => Err(FromSqlError::InvalidType),
}
.map_err(|err| FromSqlError::Other(Box::new(err)))
}
}

View File

@@ -96,6 +96,18 @@ impl From<Vec<u8>> for Value {
}
}
impl<T> From<Option<T>> for Value
where
T: Into<Value>,
{
fn from(v: Option<T>) -> Value {
match v {
Some(x) => x.into(),
None => Value::Null,
}
}
}
impl Value {
pub fn data_type(&self) -> Type {
match *self {

View File

@@ -14,7 +14,7 @@ pub enum ValueRef<'a> {
/// The value is a floating point number.
Real(f64),
/// The value is a text string.
Text(&'a str),
Text(&'a [u8]),
/// The value is a blob of data
Blob(&'a [u8]),
}
@@ -54,7 +54,9 @@ impl<'a> ValueRef<'a> {
/// `Err(Error::InvalidColumnType)`.
pub fn as_str(&self) -> FromSqlResult<&'a str> {
match *self {
ValueRef::Text(t) => Ok(t),
ValueRef::Text(t) => {
std::str::from_utf8(t).map_err(|e| FromSqlError::Other(Box::new(e)))
}
_ => Err(FromSqlError::InvalidType),
}
}
@@ -75,7 +77,10 @@ impl From<ValueRef<'_>> for Value {
ValueRef::Null => Value::Null,
ValueRef::Integer(i) => Value::Integer(i),
ValueRef::Real(r) => Value::Real(r),
ValueRef::Text(s) => Value::Text(s.to_string()),
ValueRef::Text(s) => {
let s = std::str::from_utf8(s).expect("invalid UTF-8");
Value::Text(s.to_string())
}
ValueRef::Blob(b) => Value::Blob(b.to_vec()),
}
}
@@ -83,7 +88,7 @@ impl From<ValueRef<'_>> for Value {
impl<'a> From<&'a str> for ValueRef<'a> {
fn from(s: &str) -> ValueRef<'_> {
ValueRef::Text(s)
ValueRef::Text(s.as_bytes())
}
}
@@ -99,12 +104,24 @@ impl<'a> From<&'a Value> for ValueRef<'a> {
Value::Null => ValueRef::Null,
Value::Integer(i) => ValueRef::Integer(i),
Value::Real(r) => ValueRef::Real(r),
Value::Text(ref s) => ValueRef::Text(s),
Value::Text(ref s) => ValueRef::Text(s.as_bytes()),
Value::Blob(ref b) => ValueRef::Blob(b),
}
}
}
impl<'a, T> From<Option<T>> for ValueRef<'a>
where
T: Into<ValueRef<'a>>,
{
fn from(s: Option<T>) -> ValueRef<'a> {
match s {
Some(x) => x.into(),
None => ValueRef::Null,
}
}
}
#[cfg(any(feature = "functions", feature = "session", feature = "vtab"))]
impl<'a> ValueRef<'a> {
pub(crate) unsafe fn from_value(value: *mut crate::ffi::sqlite3_value) -> ValueRef<'a> {
@@ -125,10 +142,7 @@ impl<'a> ValueRef<'a> {
);
let s = CStr::from_ptr(text as *const c_char);
// sqlite3_value_text returns UTF8 data, so our unwrap here should be fine.
let s = s
.to_str()
.expect("sqlite3_value_text returned invalid UTF-8");
let s = s.to_bytes();
ValueRef::Text(s)
}
ffi::SQLITE_BLOB => {

View File

@@ -17,6 +17,7 @@ struct UnlockNotification {
}
#[cfg(feature = "unlock_notify")]
#[allow(clippy::mutex_atomic)]
impl UnlockNotification {
fn new() -> UnlockNotification {
UnlockNotification {

View File

@@ -35,7 +35,7 @@ pub fn load_module(conn: &Connection) -> Result<()> {
conn.create_module("rarray", &ARRAY_MODULE, aux)
}
lazy_static! {
lazy_static::lazy_static! {
static ref ARRAY_MODULE: Module<ArrayTab> = eponymous_only_module::<ArrayTab>(1);
}

View File

@@ -32,7 +32,7 @@ pub fn load_module(conn: &Connection) -> Result<()> {
conn.create_module("csv", &CSV_MODULE, aux)
}
lazy_static! {
lazy_static::lazy_static! {
static ref CSV_MODULE: Module<CSVTab> = read_only_module::<CSVTab>(1);
}
@@ -235,7 +235,7 @@ impl VTab for CSVTab {
schema = Some(sql);
}
Ok((schema.unwrap().to_owned(), vtab))
Ok((schema.unwrap(), vtab))
}
// Only a forward full table scan is supported.
@@ -338,8 +338,7 @@ impl VTabCursor for CSVTabCursor {
impl From<csv::Error> for Error {
fn from(err: csv::Error) -> Error {
use std::error::Error as StdError;
Error::ModuleError(String::from(err.description()))
Error::ModuleError(err.to_string())
}
}

View File

@@ -67,8 +67,17 @@ pub struct Module<T: VTab> {
phantom: PhantomData<T>,
}
unsafe impl<T: VTab> Send for Module<T> {}
unsafe impl<T: VTab> Sync for Module<T> {}
// Used as a trailing initializer for sqlite3_module -- this way we avoid having
// the build fail if buildtime_bindgen is on, our bindings have
// `sqlite3_module::xShadowName`, but vtab_v3 wasn't specified.
fn zeroed_module() -> ffi::sqlite3_module {
// This is safe, as bindgen-generated structs are allowed to be zeroed.
unsafe { std::mem::MaybeUninit::zeroed().assume_init() }
}
/// Create a read-only virtual table implementation.
///
/// Step 2 of [Creating New Virtual Table Implementations](https://sqlite.org/vtab.html#creating_new_virtual_table_implementations).
@@ -99,8 +108,7 @@ pub fn read_only_module<T: CreateVTab>(version: c_int) -> Module<T> {
xSavepoint: None,
xRelease: None,
xRollbackTo: None,
#[cfg(any(feature = "bundled", feature = "vtab_v3"))]
xShadowName: None,
..zeroed_module()
};
Module {
base: ffi_module,
@@ -139,8 +147,7 @@ pub fn eponymous_only_module<T: VTab>(version: c_int) -> Module<T> {
xSavepoint: None,
xRelease: None,
xRollbackTo: None,
#[cfg(any(feature = "bundled", feature = "vtab_v3"))]
xShadowName: None,
..zeroed_module()
};
Module {
base: ffi_module,
@@ -161,7 +168,11 @@ impl VTabConnection {
///
/// You should not need to use this function. If you do need to, please
/// [open an issue on the rusqlite repository](https://github.com/jgallagher/rusqlite/issues) and describe
/// your use case. This function is unsafe because it gives you raw access
/// your use case.
///
/// # Safety
///
/// This function is unsafe because it gives you raw access
/// to the SQLite connection, and what you do with it could impact the
/// safety of this `Connection`.
pub unsafe fn handle(&mut self) -> *mut ffi::sqlite3 {
@@ -233,16 +244,46 @@ pub trait CreateVTab: VTab {
}
}
bitflags! {
#[doc = "Index constraint operator."]
#[repr(C)]
pub struct IndexConstraintOp: ::std::os::raw::c_uchar {
const SQLITE_INDEX_CONSTRAINT_EQ = 2;
const SQLITE_INDEX_CONSTRAINT_GT = 4;
const SQLITE_INDEX_CONSTRAINT_LE = 8;
const SQLITE_INDEX_CONSTRAINT_LT = 16;
const SQLITE_INDEX_CONSTRAINT_GE = 32;
const SQLITE_INDEX_CONSTRAINT_MATCH = 64;
///Index constraint operator.
#[derive(Debug, PartialEq)]
#[allow(non_snake_case, non_camel_case_types)]
pub enum IndexConstraintOp {
SQLITE_INDEX_CONSTRAINT_EQ,
SQLITE_INDEX_CONSTRAINT_GT,
SQLITE_INDEX_CONSTRAINT_LE,
SQLITE_INDEX_CONSTRAINT_LT,
SQLITE_INDEX_CONSTRAINT_GE,
SQLITE_INDEX_CONSTRAINT_MATCH,
SQLITE_INDEX_CONSTRAINT_LIKE, // 3.10.0
SQLITE_INDEX_CONSTRAINT_GLOB, // 3.10.0
SQLITE_INDEX_CONSTRAINT_REGEXP, // 3.10.0
SQLITE_INDEX_CONSTRAINT_NE, // 3.21.0
SQLITE_INDEX_CONSTRAINT_ISNOT, // 3.21.0
SQLITE_INDEX_CONSTRAINT_ISNOTNULL, // 3.21.0
SQLITE_INDEX_CONSTRAINT_ISNULL, // 3.21.0
SQLITE_INDEX_CONSTRAINT_IS, // 3.21.0
SQLITE_INDEX_CONSTRAINT_FUNCTION(u8), // 3.25.0
}
impl From<u8> for IndexConstraintOp {
fn from(code: u8) -> IndexConstraintOp {
match code {
2 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_EQ,
4 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_GT,
8 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LE,
16 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LT,
32 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_GE,
64 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_MATCH,
65 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_LIKE,
66 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_GLOB,
67 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_REGEXP,
68 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_NE,
69 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_ISNOT,
70 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_ISNOTNULL,
71 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_ISNULL,
72 => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_IS,
v => IndexConstraintOp::SQLITE_INDEX_CONSTRAINT_FUNCTION(v),
}
}
}
@@ -304,8 +345,8 @@ impl IndexInfo {
}
}
/// Estimated number of rows returned
#[cfg(feature = "bundled")] // SQLite >= 3.8.2
/// Estimated number of rows returned.
#[cfg(feature = "modern_sqlite")] // SQLite >= 3.8.2
pub fn set_estimated_rows(&mut self, estimated_rows: i64) {
unsafe {
(*self.0).estimatedRows = estimated_rows;
@@ -345,7 +386,7 @@ impl IndexConstraint<'_> {
/// Constraint operator
pub fn operator(&self) -> IndexConstraintOp {
IndexConstraintOp::from_bits_truncate(self.0.op)
IndexConstraintOp::from(self.0.op)
}
/// True if this constraint is usable
@@ -472,7 +513,9 @@ impl Values<'_> {
}
FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
#[cfg(feature = "i128_blob")]
FromSqlError::InvalidI128Size(_) => Error::InvalidColumnType(idx, value.data_type()),
FromSqlError::InvalidI128Size(_) => {
Error::InvalidColumnType(idx, idx.to_string(), value.data_type())
}
#[cfg(feature = "uuid")]
FromSqlError::InvalidUuidSize(_) => {
Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
@@ -643,9 +686,7 @@ unsafe extern "C" fn rust_create<T>(
where
T: CreateVTab,
{
use std::error::Error as StdError;
use std::ffi::CStr;
use std::slice;
let mut conn = VTabConnection(db);
let aux = aux as *mut T::Aux;
@@ -664,12 +705,12 @@ where
ffi::SQLITE_OK
} else {
let err = error_from_sqlite_code(rc, None);
*err_msg = mprintf(err.description());
*err_msg = mprintf(&err.to_string());
rc
}
}
Err(err) => {
*err_msg = mprintf(err.description());
*err_msg = mprintf(&err.to_string());
ffi::SQLITE_ERROR
}
},
@@ -680,7 +721,7 @@ where
err.extended_code
}
Err(err) => {
*err_msg = mprintf(err.description());
*err_msg = mprintf(&err.to_string());
ffi::SQLITE_ERROR
}
}
@@ -697,9 +738,7 @@ unsafe extern "C" fn rust_connect<T>(
where
T: VTab,
{
use std::error::Error as StdError;
use std::ffi::CStr;
use std::slice;
let mut conn = VTabConnection(db);
let aux = aux as *mut T::Aux;
@@ -718,12 +757,12 @@ where
ffi::SQLITE_OK
} else {
let err = error_from_sqlite_code(rc, None);
*err_msg = mprintf(err.description());
*err_msg = mprintf(&err.to_string());
rc
}
}
Err(err) => {
*err_msg = mprintf(err.description());
*err_msg = mprintf(&err.to_string());
ffi::SQLITE_ERROR
}
},
@@ -734,7 +773,7 @@ where
err.extended_code
}
Err(err) => {
*err_msg = mprintf(err.description());
*err_msg = mprintf(&err.to_string());
ffi::SQLITE_ERROR
}
}
@@ -747,7 +786,6 @@ unsafe extern "C" fn rust_best_index<T>(
where
T: VTab,
{
use std::error::Error as StdError;
let vt = vtab as *mut T;
let mut idx_info = IndexInfo(info);
match (*vt).best_index(&mut idx_info) {
@@ -759,7 +797,7 @@ where
err.extended_code
}
Err(err) => {
set_err_msg(vtab, err.description());
set_err_msg(vtab, &err.to_string());
ffi::SQLITE_ERROR
}
}
@@ -781,7 +819,6 @@ unsafe extern "C" fn rust_destroy<T>(vtab: *mut ffi::sqlite3_vtab) -> c_int
where
T: CreateVTab,
{
use std::error::Error as StdError;
if vtab.is_null() {
return ffi::SQLITE_OK;
}
@@ -798,7 +835,7 @@ where
err.extended_code
}
Err(err) => {
set_err_msg(vtab, err.description());
set_err_msg(vtab, &err.to_string());
ffi::SQLITE_ERROR
}
}
@@ -811,7 +848,6 @@ unsafe extern "C" fn rust_open<T>(
where
T: VTab,
{
use std::error::Error as StdError;
let vt = vtab as *mut T;
match (*vt).open() {
Ok(cursor) => {
@@ -826,7 +862,7 @@ where
err.extended_code
}
Err(err) => {
set_err_msg(vtab, err.description());
set_err_msg(vtab, &err.to_string());
ffi::SQLITE_ERROR
}
}
@@ -852,7 +888,6 @@ where
C: VTabCursor,
{
use std::ffi::CStr;
use std::slice;
use std::str;
let idx_name = if idx_str.is_null() {
None
@@ -915,7 +950,6 @@ where
/// Virtual table cursors can set an error message by assigning a string to
/// `zErrMsg`.
unsafe fn cursor_error<T>(cursor: *mut ffi::sqlite3_vtab_cursor, result: Result<T>) -> c_int {
use std::error::Error as StdError;
match result {
Ok(_) => ffi::SQLITE_OK,
Err(Error::SqliteFailure(err, s)) => {
@@ -925,7 +959,7 @@ unsafe fn cursor_error<T>(cursor: *mut ffi::sqlite3_vtab_cursor, result: Result<
err.extended_code
}
Err(err) => {
set_err_msg((*cursor).pVtab, err.description());
set_err_msg((*cursor).pVtab, &err.to_string());
ffi::SQLITE_ERROR
}
}
@@ -943,7 +977,6 @@ unsafe fn set_err_msg(vtab: *mut ffi::sqlite3_vtab, err_msg: &str) {
/// To raise an error, the `column` method should use this method to set the
/// error message and return the error code.
unsafe fn result_error<T>(ctx: *mut ffi::sqlite3_context, result: Result<T>) -> c_int {
use std::error::Error as StdError;
match result {
Ok(_) => ffi::SQLITE_OK,
Err(Error::SqliteFailure(err, s)) => {
@@ -965,7 +998,7 @@ unsafe fn result_error<T>(ctx: *mut ffi::sqlite3_context, result: Result<T>) ->
}
Err(err) => {
ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_ERROR);
if let Ok(cstr) = str_to_cstring(err.description()) {
if let Ok(cstr) = str_to_cstring(&err.to_string()) {
ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
}
ffi::SQLITE_ERROR
@@ -985,7 +1018,7 @@ fn mprintf(err_msg: &str) -> *mut c_char {
pub mod array;
#[cfg(feature = "csvtab")]
pub mod csvtab;
#[cfg(feature = "bundled")]
#[cfg(feature = "series")]
pub mod series; // SQLite >= 3.9.0
#[cfg(test)]

View File

@@ -18,7 +18,7 @@ pub fn load_module(conn: &Connection) -> Result<()> {
conn.create_module("generate_series", &SERIES_MODULE, aux)
}
lazy_static! {
lazy_static::lazy_static! {
static ref SERIES_MODULE: Module<SeriesTab> = eponymous_only_module::<SeriesTab>(1);
}
@@ -28,7 +28,7 @@ const SERIES_COLUMN_START: c_int = 1;
const SERIES_COLUMN_STOP: c_int = 2;
const SERIES_COLUMN_STEP: c_int = 3;
bitflags! {
bitflags::bitflags! {
#[repr(C)]
struct QueryPlanFlags: ::std::os::raw::c_int {
// start = $value -- constraint exists