Check that only one statement is provided (#397)

Connection.execute
Connection.execute_named
Connection.quer_row
Connection.quer_row_named
This commit is contained in:
gwenn 2018-10-28 10:28:19 +01:00
parent ebf98b4241
commit 77cb50e000
4 changed files with 56 additions and 13 deletions

View File

@ -95,6 +95,9 @@ pub enum Error {
#[cfg(feature = "vtab")] #[cfg(feature = "vtab")]
#[allow(dead_code)] #[allow(dead_code)]
ModuleError(String), ModuleError(String),
/// Error when the SQL contains multiple statements.
MultipleStatement,
} }
impl From<str::Utf8Error> for Error { impl From<str::Utf8Error> for Error {
@ -155,6 +158,7 @@ impl fmt::Display for Error {
Error::InvalidQuery => write!(f, "Query is not read-only"), Error::InvalidQuery => write!(f, "Query is not read-only"),
#[cfg(feature = "vtab")] #[cfg(feature = "vtab")]
Error::ModuleError(ref desc) => write!(f, "{}", desc), Error::ModuleError(ref desc) => write!(f, "{}", desc),
Error::MultipleStatement => write!(f, "Multiple statements provided"),
} }
} }
} }
@ -192,6 +196,7 @@ impl error::Error for Error {
Error::InvalidQuery => "query is not read-only", Error::InvalidQuery => "query is not read-only",
#[cfg(feature = "vtab")] #[cfg(feature = "vtab")]
Error::ModuleError(ref desc) => desc, Error::ModuleError(ref desc) => desc,
Error::MultipleStatement => "multiple statements provided",
} }
} }
@ -211,7 +216,8 @@ impl error::Error for Error {
| Error::InvalidColumnType(_, _) | Error::InvalidColumnType(_, _)
| Error::InvalidPath(_) | Error::InvalidPath(_)
| Error::StatementChangedRows(_) | Error::StatementChangedRows(_)
| Error::InvalidQuery => None, | Error::InvalidQuery
| Error::MultipleStatement => None,
#[cfg(feature = "functions")] #[cfg(feature = "functions")]
Error::InvalidFunctionParameterType(_, _) => None, Error::InvalidFunctionParameterType(_, _) => None,

View File

@ -331,7 +331,8 @@ impl Connection {
P: IntoIterator, P: IntoIterator,
P::Item: ToSql, P::Item: ToSql,
{ {
self.prepare(sql).and_then(|mut stmt| stmt.execute(params)) self.prepare(sql)
.and_then(|mut stmt| stmt.check_no_tail().and_then(|_| stmt.execute(params)))
} }
/// Convenience method to prepare and execute a single SQL statement with /// Convenience method to prepare and execute a single SQL statement with
@ -357,8 +358,10 @@ impl Connection {
/// Will return `Err` if `sql` cannot be converted to a C-compatible string /// Will return `Err` if `sql` cannot be converted to a C-compatible string
/// or if the underlying SQLite call fails. /// or if the underlying SQLite call fails.
pub fn execute_named(&self, sql: &str, params: &[(&str, &ToSql)]) -> Result<usize> { pub fn execute_named(&self, sql: &str, params: &[(&str, &ToSql)]) -> Result<usize> {
self.prepare(sql) self.prepare(sql).and_then(|mut stmt| {
.and_then(|mut stmt| stmt.execute_named(params)) stmt.check_no_tail()
.and_then(|_| stmt.execute_named(params))
})
} }
/// Get the SQLite rowid of the most recent successful INSERT. /// Get the SQLite rowid of the most recent successful INSERT.
@ -399,6 +402,7 @@ impl Connection {
F: FnOnce(&Row) -> T, F: FnOnce(&Row) -> T,
{ {
let mut stmt = try!(self.prepare(sql)); let mut stmt = try!(self.prepare(sql));
try!(stmt.check_no_tail());
stmt.query_row(params, f) stmt.query_row(params, f)
} }
@ -417,6 +421,7 @@ impl Connection {
F: FnOnce(&Row) -> T, F: FnOnce(&Row) -> T,
{ {
let mut stmt = try!(self.prepare(sql)); let mut stmt = try!(self.prepare(sql));
try!(stmt.check_no_tail());
let mut rows = try!(stmt.query_named(params)); let mut rows = try!(stmt.query_named(params));
rows.get_expected_row().map(|r| f(&r)) rows.get_expected_row().map(|r| f(&r))
@ -455,6 +460,7 @@ impl Connection {
E: convert::From<Error>, E: convert::From<Error>,
{ {
let mut stmt = try!(self.prepare(sql)); let mut stmt = try!(self.prepare(sql));
try!(stmt.check_no_tail());
let mut rows = try!(stmt.query(params)); let mut rows = try!(stmt.query(params));
rows.get_expected_row().map_err(E::from).and_then(|r| f(&r)) rows.get_expected_row().map_err(E::from).and_then(|r| f(&r))
@ -975,6 +981,7 @@ impl InnerConnection {
let mut c_stmt: *mut ffi::sqlite3_stmt = unsafe { mem::uninitialized() }; let mut c_stmt: *mut ffi::sqlite3_stmt = unsafe { mem::uninitialized() };
let c_sql = try!(str_to_cstring(sql)); let c_sql = try!(str_to_cstring(sql));
let len_with_nul = (sql.len() + 1) as c_int; let len_with_nul = (sql.len() + 1) as c_int;
let mut c_tail = ptr::null();
let r = unsafe { let r = unsafe {
if cfg!(feature = "unlock_notify") { if cfg!(feature = "unlock_notify") {
let mut rc; let mut rc;
@ -984,7 +991,7 @@ impl InnerConnection {
c_sql.as_ptr(), c_sql.as_ptr(),
len_with_nul, len_with_nul,
&mut c_stmt, &mut c_stmt,
ptr::null_mut(), &mut c_tail,
); );
if !unlock_notify::is_locked(self.db, rc) { if !unlock_notify::is_locked(self.db, rc) {
break; break;
@ -1001,12 +1008,16 @@ impl InnerConnection {
c_sql.as_ptr(), c_sql.as_ptr(),
len_with_nul, len_with_nul,
&mut c_stmt, &mut c_stmt,
ptr::null_mut(), &mut c_tail,
) )
} }
}; };
if !c_tail.is_null() && unsafe { *c_tail == 0 } {
// '\0' when there is no ';' at the end
c_tail = ptr::null(); // TODO ignore spaces, comments, ... at the end
}
self.decode_result(r) self.decode_result(r)
.map(|_| Statement::new(conn, RawStatement::new(c_stmt))) .map(|_| Statement::new(conn, RawStatement::new(c_stmt, c_tail)))
} }
fn changes(&mut self) -> usize { fn changes(&mut self) -> usize {
@ -1289,6 +1300,20 @@ mod test {
} }
} }
#[test]
fn test_execute_multiple() {
let db = checked_memory_handle();
let err = db
.execute(
"CREATE TABLE foo(x INTEGER); CREATE TABLE foo(x INTEGER)",
NO_PARAMS,
).unwrap_err();
match err {
Error::MultipleStatement => (),
_ => panic!("Unexpected error: {}", err),
}
}
#[test] #[test]
fn test_prepare_column_names() { fn test_prepare_column_names() {
let db = checked_memory_handle(); let db = checked_memory_handle();

View File

@ -1,16 +1,16 @@
use super::ffi; use super::ffi;
use super::unlock_notify; use super::unlock_notify;
use std::ffi::CStr; use std::ffi::CStr;
use std::os::raw::c_int; use std::os::raw::{c_char, c_int};
use std::ptr; use std::ptr;
// Private newtype for raw sqlite3_stmts that finalize themselves when dropped. // Private newtype for raw sqlite3_stmts that finalize themselves when dropped.
#[derive(Debug)] #[derive(Debug)]
pub struct RawStatement(*mut ffi::sqlite3_stmt); pub struct RawStatement(*mut ffi::sqlite3_stmt, *const c_char);
impl RawStatement { impl RawStatement {
pub fn new(stmt: *mut ffi::sqlite3_stmt) -> RawStatement { pub fn new(stmt: *mut ffi::sqlite3_stmt, tail: *const c_char) -> RawStatement {
RawStatement(stmt) RawStatement(stmt, tail)
} }
pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt { pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt {
@ -100,6 +100,10 @@ impl RawStatement {
} }
} }
} }
pub fn has_tail(&self) -> bool {
!self.1.is_null()
}
} }
impl Drop for RawStatement { impl Drop for RawStatement {

View File

@ -540,7 +540,7 @@ impl<'conn> Statement<'conn> {
} }
fn finalize_(&mut self) -> Result<()> { fn finalize_(&mut self) -> Result<()> {
let mut stmt = RawStatement::new(ptr::null_mut()); let mut stmt = RawStatement::new(ptr::null_mut(), ptr::null());
mem::swap(&mut stmt, &mut self.stmt); mem::swap(&mut stmt, &mut self.stmt);
self.conn.decode_result(stmt.finalize()) self.conn.decode_result(stmt.finalize())
} }
@ -570,11 +570,19 @@ impl<'conn> Statement<'conn> {
.map(|s| str::from_utf8_unchecked(s.to_bytes())) .map(|s| str::from_utf8_unchecked(s.to_bytes()))
} }
} }
pub(crate) fn check_no_tail(&self) -> Result<()> {
if self.stmt.has_tail() {
Err(Error::MultipleStatement)
} else {
Ok(())
}
}
} }
impl<'conn> Into<RawStatement> for Statement<'conn> { impl<'conn> Into<RawStatement> for Statement<'conn> {
fn into(mut self) -> RawStatement { fn into(mut self) -> RawStatement {
let mut stmt = RawStatement::new(ptr::null_mut()); let mut stmt = RawStatement::new(ptr::null_mut(), ptr::null());
mem::swap(&mut stmt, &mut self.stmt); mem::swap(&mut stmt, &mut self.stmt);
stmt stmt
} }