From 77cb50e00016ad09f3f84c94fb28366e98f34c51 Mon Sep 17 00:00:00 2001 From: gwenn Date: Sun, 28 Oct 2018 10:28:19 +0100 Subject: [PATCH] Check that only one statement is provided (#397) Connection.execute Connection.execute_named Connection.quer_row Connection.quer_row_named --- src/error.rs | 8 +++++++- src/lib.rs | 37 +++++++++++++++++++++++++++++++------ src/raw_statement.rs | 12 ++++++++---- src/statement.rs | 12 ++++++++++-- 4 files changed, 56 insertions(+), 13 deletions(-) diff --git a/src/error.rs b/src/error.rs index 27674f8..5708458 100644 --- a/src/error.rs +++ b/src/error.rs @@ -95,6 +95,9 @@ pub enum Error { #[cfg(feature = "vtab")] #[allow(dead_code)] ModuleError(String), + + /// Error when the SQL contains multiple statements. + MultipleStatement, } impl From for Error { @@ -155,6 +158,7 @@ impl fmt::Display for Error { Error::InvalidQuery => write!(f, "Query is not read-only"), #[cfg(feature = "vtab")] Error::ModuleError(ref desc) => write!(f, "{}", desc), + Error::MultipleStatement => write!(f, "Multiple statements provided"), } } } @@ -192,6 +196,7 @@ impl error::Error for Error { Error::InvalidQuery => "query is not read-only", #[cfg(feature = "vtab")] Error::ModuleError(ref desc) => desc, + Error::MultipleStatement => "multiple statements provided", } } @@ -211,7 +216,8 @@ impl error::Error for Error { | Error::InvalidColumnType(_, _) | Error::InvalidPath(_) | Error::StatementChangedRows(_) - | Error::InvalidQuery => None, + | Error::InvalidQuery + | Error::MultipleStatement => None, #[cfg(feature = "functions")] Error::InvalidFunctionParameterType(_, _) => None, diff --git a/src/lib.rs b/src/lib.rs index c1d7c80..3ce06ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -331,7 +331,8 @@ impl Connection { P: IntoIterator, P::Item: ToSql, { - self.prepare(sql).and_then(|mut stmt| stmt.execute(params)) + self.prepare(sql) + .and_then(|mut stmt| stmt.check_no_tail().and_then(|_| stmt.execute(params))) } /// Convenience method to prepare and execute a single SQL statement with @@ -357,8 +358,10 @@ impl Connection { /// Will return `Err` if `sql` cannot be converted to a C-compatible string /// or if the underlying SQLite call fails. pub fn execute_named(&self, sql: &str, params: &[(&str, &ToSql)]) -> Result { - self.prepare(sql) - .and_then(|mut stmt| stmt.execute_named(params)) + self.prepare(sql).and_then(|mut stmt| { + stmt.check_no_tail() + .and_then(|_| stmt.execute_named(params)) + }) } /// Get the SQLite rowid of the most recent successful INSERT. @@ -399,6 +402,7 @@ impl Connection { F: FnOnce(&Row) -> T, { let mut stmt = try!(self.prepare(sql)); + try!(stmt.check_no_tail()); stmt.query_row(params, f) } @@ -417,6 +421,7 @@ impl Connection { F: FnOnce(&Row) -> T, { let mut stmt = try!(self.prepare(sql)); + try!(stmt.check_no_tail()); let mut rows = try!(stmt.query_named(params)); rows.get_expected_row().map(|r| f(&r)) @@ -455,6 +460,7 @@ impl Connection { E: convert::From, { let mut stmt = try!(self.prepare(sql)); + try!(stmt.check_no_tail()); let mut rows = try!(stmt.query(params)); rows.get_expected_row().map_err(E::from).and_then(|r| f(&r)) @@ -975,6 +981,7 @@ impl InnerConnection { let mut c_stmt: *mut ffi::sqlite3_stmt = unsafe { mem::uninitialized() }; let c_sql = try!(str_to_cstring(sql)); let len_with_nul = (sql.len() + 1) as c_int; + let mut c_tail = ptr::null(); let r = unsafe { if cfg!(feature = "unlock_notify") { let mut rc; @@ -984,7 +991,7 @@ impl InnerConnection { c_sql.as_ptr(), len_with_nul, &mut c_stmt, - ptr::null_mut(), + &mut c_tail, ); if !unlock_notify::is_locked(self.db, rc) { break; @@ -1001,12 +1008,16 @@ impl InnerConnection { c_sql.as_ptr(), len_with_nul, &mut c_stmt, - ptr::null_mut(), + &mut c_tail, ) } }; + if !c_tail.is_null() && unsafe { *c_tail == 0 } { + // '\0' when there is no ';' at the end + c_tail = ptr::null(); // TODO ignore spaces, comments, ... at the end + } self.decode_result(r) - .map(|_| Statement::new(conn, RawStatement::new(c_stmt))) + .map(|_| Statement::new(conn, RawStatement::new(c_stmt, c_tail))) } fn changes(&mut self) -> usize { @@ -1289,6 +1300,20 @@ mod test { } } + #[test] + fn test_execute_multiple() { + let db = checked_memory_handle(); + let err = db + .execute( + "CREATE TABLE foo(x INTEGER); CREATE TABLE foo(x INTEGER)", + NO_PARAMS, + ).unwrap_err(); + match err { + Error::MultipleStatement => (), + _ => panic!("Unexpected error: {}", err), + } + } + #[test] fn test_prepare_column_names() { let db = checked_memory_handle(); diff --git a/src/raw_statement.rs b/src/raw_statement.rs index b9d405d..2fe2d3c 100644 --- a/src/raw_statement.rs +++ b/src/raw_statement.rs @@ -1,16 +1,16 @@ use super::ffi; use super::unlock_notify; use std::ffi::CStr; -use std::os::raw::c_int; +use std::os::raw::{c_char, c_int}; use std::ptr; // Private newtype for raw sqlite3_stmts that finalize themselves when dropped. #[derive(Debug)] -pub struct RawStatement(*mut ffi::sqlite3_stmt); +pub struct RawStatement(*mut ffi::sqlite3_stmt, *const c_char); impl RawStatement { - pub fn new(stmt: *mut ffi::sqlite3_stmt) -> RawStatement { - RawStatement(stmt) + pub fn new(stmt: *mut ffi::sqlite3_stmt, tail: *const c_char) -> RawStatement { + RawStatement(stmt, tail) } pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt { @@ -100,6 +100,10 @@ impl RawStatement { } } } + + pub fn has_tail(&self) -> bool { + !self.1.is_null() + } } impl Drop for RawStatement { diff --git a/src/statement.rs b/src/statement.rs index f1c7765..89cd1cb 100644 --- a/src/statement.rs +++ b/src/statement.rs @@ -540,7 +540,7 @@ impl<'conn> Statement<'conn> { } fn finalize_(&mut self) -> Result<()> { - let mut stmt = RawStatement::new(ptr::null_mut()); + let mut stmt = RawStatement::new(ptr::null_mut(), ptr::null()); mem::swap(&mut stmt, &mut self.stmt); self.conn.decode_result(stmt.finalize()) } @@ -570,11 +570,19 @@ impl<'conn> Statement<'conn> { .map(|s| str::from_utf8_unchecked(s.to_bytes())) } } + + pub(crate) fn check_no_tail(&self) -> Result<()> { + if self.stmt.has_tail() { + Err(Error::MultipleStatement) + } else { + Ok(()) + } + } } impl<'conn> Into for Statement<'conn> { fn into(mut self) -> RawStatement { - let mut stmt = RawStatement::new(ptr::null_mut()); + let mut stmt = RawStatement::new(ptr::null_mut(), ptr::null()); mem::swap(&mut stmt, &mut self.stmt); stmt }