Merge pull request #421 from gwenn/tail

Check that only one statement is provided (#397)
This commit is contained in:
gwenn 2019-08-31 09:54:27 +02:00 committed by GitHub
commit e85ebedb58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 72 additions and 16 deletions

View File

@ -100,6 +100,9 @@ pub enum Error {
/// of a different type than what had been stored using `Context::set_aux`. /// of a different type than what had been stored using `Context::set_aux`.
#[cfg(feature = "functions")] #[cfg(feature = "functions")]
GetAuxWrongType, GetAuxWrongType,
/// Error when the SQL contains multiple statements.
MultipleStatement,
} }
impl PartialEq for Error { impl PartialEq for Error {
@ -244,6 +247,7 @@ impl fmt::Display for Error {
Error::UnwindingPanic => write!(f, "unwinding panic"), Error::UnwindingPanic => write!(f, "unwinding panic"),
#[cfg(feature = "functions")] #[cfg(feature = "functions")]
Error::GetAuxWrongType => write!(f, "get_aux called with wrong type"), Error::GetAuxWrongType => write!(f, "get_aux called with wrong type"),
Error::MultipleStatement => write!(f, "Multiple statements provided"),
} }
} }
} }
@ -285,6 +289,7 @@ impl error::Error for Error {
Error::UnwindingPanic => "unwinding panic", Error::UnwindingPanic => "unwinding panic",
#[cfg(feature = "functions")] #[cfg(feature = "functions")]
Error::GetAuxWrongType => "get_aux called with wrong type", Error::GetAuxWrongType => "get_aux called with wrong type",
Error::MultipleStatement => "multiple statements provided",
} }
} }
@ -304,7 +309,8 @@ impl error::Error for Error {
| Error::InvalidColumnType(_, _, _) | Error::InvalidColumnType(_, _, _)
| Error::InvalidPath(_) | Error::InvalidPath(_)
| Error::StatementChangedRows(_) | Error::StatementChangedRows(_)
| Error::InvalidQuery => None, | Error::InvalidQuery
| Error::MultipleStatement => None,
#[cfg(feature = "functions")] #[cfg(feature = "functions")]
Error::InvalidFunctionParameterType(_, _) => None, Error::InvalidFunctionParameterType(_, _) => None,

View File

@ -1,6 +1,6 @@
use std::ffi::CString; use std::ffi::CString;
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::os::raw::c_int; use std::os::raw::{c_char, c_int};
#[cfg(feature = "load_extension")] #[cfg(feature = "load_extension")]
use std::path::Path; use std::path::Path;
use std::ptr; use std::ptr;
@ -178,8 +178,6 @@ impl InnerConnection {
#[cfg(feature = "load_extension")] #[cfg(feature = "load_extension")]
pub fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> Result<()> { pub fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> Result<()> {
use std::os::raw::c_char;
let dylib_str = super::path_to_cstring(dylib_path)?; let dylib_str = super::path_to_cstring(dylib_path)?;
unsafe { unsafe {
let mut errmsg = MaybeUninit::uninit(); let mut errmsg = MaybeUninit::uninit();
@ -217,6 +215,7 @@ impl InnerConnection {
pub fn prepare<'a>(&mut self, conn: &'a Connection, sql: &str) -> Result<Statement<'a>> { pub fn prepare<'a>(&mut self, conn: &'a Connection, sql: &str) -> Result<Statement<'a>> {
let mut c_stmt = MaybeUninit::uninit(); let mut c_stmt = MaybeUninit::uninit();
let (c_sql, len, _) = str_for_sqlite(sql.as_bytes())?; let (c_sql, len, _) = str_for_sqlite(sql.as_bytes())?;
let mut c_tail = MaybeUninit::uninit();
let r = unsafe { let r = unsafe {
if cfg!(feature = "unlock_notify") { if cfg!(feature = "unlock_notify") {
let mut rc; let mut rc;
@ -226,7 +225,7 @@ impl InnerConnection {
c_sql, c_sql,
len, len,
c_stmt.as_mut_ptr(), c_stmt.as_mut_ptr(),
ptr::null_mut(), c_tail.as_mut_ptr(),
); );
if !unlock_notify::is_locked(self.db, rc) { if !unlock_notify::is_locked(self.db, rc) {
break; break;
@ -238,12 +237,22 @@ impl InnerConnection {
} }
rc rc
} else { } else {
ffi::sqlite3_prepare_v2(self.db(), c_sql, len, c_stmt.as_mut_ptr(), ptr::null_mut()) ffi::sqlite3_prepare_v2(
self.db(),
c_sql,
len,
c_stmt.as_mut_ptr(),
c_tail.as_mut_ptr(),
)
} }
}; };
self.decode_result(r)?;
let c_stmt: *mut ffi::sqlite3_stmt = unsafe { c_stmt.assume_init() }; let c_stmt: *mut ffi::sqlite3_stmt = unsafe { c_stmt.assume_init() };
self.decode_result(r) let c_tail: *const c_char = unsafe { c_tail.assume_init() };
.map(|_| Statement::new(conn, RawStatement::new(c_stmt))) // TODO ignore spaces, comments, ... at the end
let tail = !c_tail.is_null() && unsafe { c_tail != c_sql.offset(len as isize) };
Ok(Statement::new(conn, RawStatement::new(c_stmt, tail)))
} }
pub fn changes(&mut self) -> usize { pub fn changes(&mut self) -> usize {

View File

@ -449,7 +449,8 @@ impl Connection {
P: IntoIterator, P: IntoIterator,
P::Item: ToSql, P::Item: ToSql,
{ {
self.prepare(sql).and_then(|mut stmt| stmt.execute(params)) self.prepare(sql)
.and_then(|mut stmt| stmt.check_no_tail().and_then(|_| stmt.execute(params)))
} }
/// Convenience method to prepare and execute a single SQL statement with /// Convenience method to prepare and execute a single SQL statement with
@ -475,8 +476,10 @@ impl Connection {
/// Will return `Err` if `sql` cannot be converted to a C-compatible string /// Will return `Err` if `sql` cannot be converted to a C-compatible string
/// or if the underlying SQLite call fails. /// or if the underlying SQLite call fails.
pub fn execute_named(&self, sql: &str, params: &[(&str, &dyn ToSql)]) -> Result<usize> { pub fn execute_named(&self, sql: &str, params: &[(&str, &dyn ToSql)]) -> Result<usize> {
self.prepare(sql) self.prepare(sql).and_then(|mut stmt| {
.and_then(|mut stmt| stmt.execute_named(params)) stmt.check_no_tail()
.and_then(|_| stmt.execute_named(params))
})
} }
/// Get the SQLite rowid of the most recent successful INSERT. /// Get the SQLite rowid of the most recent successful INSERT.
@ -521,6 +524,7 @@ impl Connection {
F: FnOnce(&Row<'_>) -> Result<T>, F: FnOnce(&Row<'_>) -> Result<T>,
{ {
let mut stmt = self.prepare(sql)?; let mut stmt = self.prepare(sql)?;
stmt.check_no_tail()?;
stmt.query_row(params, f) stmt.query_row(params, f)
} }
@ -543,6 +547,7 @@ impl Connection {
F: FnOnce(&Row<'_>) -> Result<T>, F: FnOnce(&Row<'_>) -> Result<T>,
{ {
let mut stmt = self.prepare(sql)?; let mut stmt = self.prepare(sql)?;
stmt.check_no_tail()?;
stmt.query_row_named(params, f) stmt.query_row_named(params, f)
} }
@ -579,6 +584,7 @@ impl Connection {
E: convert::From<Error>, E: convert::From<Error>,
{ {
let mut stmt = self.prepare(sql)?; let mut stmt = self.prepare(sql)?;
stmt.check_no_tail()?;
let mut rows = stmt.query(params)?; let mut rows = stmt.query(params)?;
rows.get_expected_row().map_err(E::from).and_then(|r| f(&r)) rows.get_expected_row().map_err(E::from).and_then(|r| f(&r))
@ -1054,6 +1060,22 @@ mod test {
} }
} }
#[test]
#[cfg(feature = "extra_check")]
fn test_execute_multiple() {
let db = checked_memory_handle();
let err = db
.execute(
"CREATE TABLE foo(x INTEGER); CREATE TABLE foo(x INTEGER)",
NO_PARAMS,
)
.unwrap_err();
match err {
Error::MultipleStatement => (),
_ => panic!("Unexpected error: {}", err),
}
}
#[test] #[test]
fn test_prepare_column_names() { fn test_prepare_column_names() {
let db = checked_memory_handle(); let db = checked_memory_handle();

View File

@ -7,11 +7,11 @@ use std::ptr;
// Private newtype for raw sqlite3_stmts that finalize themselves when dropped. // Private newtype for raw sqlite3_stmts that finalize themselves when dropped.
#[derive(Debug)] #[derive(Debug)]
pub struct RawStatement(*mut ffi::sqlite3_stmt); pub struct RawStatement(*mut ffi::sqlite3_stmt, bool);
impl RawStatement { impl RawStatement {
pub fn new(stmt: *mut ffi::sqlite3_stmt) -> RawStatement { pub fn new(stmt: *mut ffi::sqlite3_stmt, tail: bool) -> RawStatement {
RawStatement(stmt) RawStatement(stmt, tail)
} }
pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt { pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt {
@ -129,6 +129,10 @@ impl RawStatement {
assert!(!self.0.is_null()); assert!(!self.0.is_null());
unsafe { ffi::sqlite3_stmt_status(self.0, status as i32, reset as i32) } unsafe { ffi::sqlite3_stmt_status(self.0, status as i32, reset as i32) }
} }
pub fn has_tail(&self) -> bool {
self.1
}
} }
impl Drop for RawStatement { impl Drop for RawStatement {

View File

@ -528,7 +528,7 @@ impl Statement<'_> {
} }
fn finalize_(&mut self) -> Result<()> { fn finalize_(&mut self) -> Result<()> {
let mut stmt = RawStatement::new(ptr::null_mut()); let mut stmt = RawStatement::new(ptr::null_mut(), false);
mem::swap(&mut stmt, &mut self.stmt); mem::swap(&mut stmt, &mut self.stmt);
self.conn.decode_result(stmt.finalize()) self.conn.decode_result(stmt.finalize())
} }
@ -598,11 +598,26 @@ impl Statement<'_> {
pub fn reset_status(&self, status: StatementStatus) -> i32 { pub fn reset_status(&self, status: StatementStatus) -> i32 {
self.stmt.get_status(status, true) self.stmt.get_status(status, true)
} }
#[cfg(feature = "extra_check")]
pub(crate) fn check_no_tail(&self) -> Result<()> {
if self.stmt.has_tail() {
Err(Error::MultipleStatement)
} else {
Ok(())
}
}
#[cfg(not(feature = "extra_check"))]
#[inline]
pub(crate) fn check_no_tail(&self) -> Result<()> {
Ok(())
}
} }
impl Into<RawStatement> for Statement<'_> { impl Into<RawStatement> for Statement<'_> {
fn into(mut self) -> RawStatement { fn into(mut self) -> RawStatement {
let mut stmt = RawStatement::new(ptr::null_mut()); let mut stmt = RawStatement::new(ptr::null_mut(), false);
mem::swap(&mut stmt, &mut self.stmt); mem::swap(&mut stmt, &mut self.stmt);
stmt stmt
} }