diff --git a/src/error.rs b/src/error.rs index b932465..ef22f61 100644 --- a/src/error.rs +++ b/src/error.rs @@ -100,6 +100,9 @@ pub enum Error { /// of a different type than what had been stored using `Context::set_aux`. #[cfg(feature = "functions")] GetAuxWrongType, + + /// Error when the SQL contains multiple statements. + MultipleStatement, } impl PartialEq for Error { @@ -244,6 +247,7 @@ impl fmt::Display for Error { Error::UnwindingPanic => write!(f, "unwinding panic"), #[cfg(feature = "functions")] Error::GetAuxWrongType => write!(f, "get_aux called with wrong type"), + Error::MultipleStatement => write!(f, "Multiple statements provided"), } } } @@ -285,6 +289,7 @@ impl error::Error for Error { Error::UnwindingPanic => "unwinding panic", #[cfg(feature = "functions")] Error::GetAuxWrongType => "get_aux called with wrong type", + Error::MultipleStatement => "multiple statements provided", } } @@ -304,7 +309,8 @@ impl error::Error for Error { | Error::InvalidColumnType(_, _, _) | Error::InvalidPath(_) | Error::StatementChangedRows(_) - | Error::InvalidQuery => None, + | Error::InvalidQuery + | Error::MultipleStatement => None, #[cfg(feature = "functions")] Error::InvalidFunctionParameterType(_, _) => None, diff --git a/src/inner_connection.rs b/src/inner_connection.rs index 327747a..75b1242 100644 --- a/src/inner_connection.rs +++ b/src/inner_connection.rs @@ -1,6 +1,6 @@ use std::ffi::CString; use std::mem::MaybeUninit; -use std::os::raw::c_int; +use std::os::raw::{c_char, c_int}; #[cfg(feature = "load_extension")] use std::path::Path; use std::ptr; @@ -178,8 +178,6 @@ impl InnerConnection { #[cfg(feature = "load_extension")] pub fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> Result<()> { - use std::os::raw::c_char; - let dylib_str = super::path_to_cstring(dylib_path)?; unsafe { let mut errmsg = MaybeUninit::uninit(); @@ -217,6 +215,7 @@ impl InnerConnection { pub fn prepare<'a>(&mut self, conn: &'a Connection, sql: &str) -> Result> { let mut c_stmt = MaybeUninit::uninit(); let (c_sql, len, _) = str_for_sqlite(sql.as_bytes())?; + let mut c_tail = MaybeUninit::uninit(); let r = unsafe { if cfg!(feature = "unlock_notify") { let mut rc; @@ -226,7 +225,7 @@ impl InnerConnection { c_sql, len, c_stmt.as_mut_ptr(), - ptr::null_mut(), + c_tail.as_mut_ptr(), ); if !unlock_notify::is_locked(self.db, rc) { break; @@ -238,12 +237,22 @@ impl InnerConnection { } rc } else { - ffi::sqlite3_prepare_v2(self.db(), c_sql, len, c_stmt.as_mut_ptr(), ptr::null_mut()) + ffi::sqlite3_prepare_v2( + self.db(), + c_sql, + len, + c_stmt.as_mut_ptr(), + c_tail.as_mut_ptr(), + ) } }; + self.decode_result(r)?; + let c_stmt: *mut ffi::sqlite3_stmt = unsafe { c_stmt.assume_init() }; - self.decode_result(r) - .map(|_| Statement::new(conn, RawStatement::new(c_stmt))) + let c_tail: *const c_char = unsafe { c_tail.assume_init() }; + // TODO ignore spaces, comments, ... at the end + let tail = !c_tail.is_null() && unsafe { c_tail != c_sql.offset(len as isize) }; + Ok(Statement::new(conn, RawStatement::new(c_stmt, tail))) } pub fn changes(&mut self) -> usize { diff --git a/src/lib.rs b/src/lib.rs index 7bbdb24..7e302d9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -449,7 +449,8 @@ impl Connection { P: IntoIterator, P::Item: ToSql, { - self.prepare(sql).and_then(|mut stmt| stmt.execute(params)) + self.prepare(sql) + .and_then(|mut stmt| stmt.check_no_tail().and_then(|_| stmt.execute(params))) } /// Convenience method to prepare and execute a single SQL statement with @@ -475,8 +476,10 @@ impl Connection { /// Will return `Err` if `sql` cannot be converted to a C-compatible string /// or if the underlying SQLite call fails. pub fn execute_named(&self, sql: &str, params: &[(&str, &dyn ToSql)]) -> Result { - self.prepare(sql) - .and_then(|mut stmt| stmt.execute_named(params)) + self.prepare(sql).and_then(|mut stmt| { + stmt.check_no_tail() + .and_then(|_| stmt.execute_named(params)) + }) } /// Get the SQLite rowid of the most recent successful INSERT. @@ -521,6 +524,7 @@ impl Connection { F: FnOnce(&Row<'_>) -> Result, { let mut stmt = self.prepare(sql)?; + stmt.check_no_tail()?; stmt.query_row(params, f) } @@ -543,6 +547,7 @@ impl Connection { F: FnOnce(&Row<'_>) -> Result, { let mut stmt = self.prepare(sql)?; + stmt.check_no_tail()?; stmt.query_row_named(params, f) } @@ -579,6 +584,7 @@ impl Connection { E: convert::From, { let mut stmt = self.prepare(sql)?; + stmt.check_no_tail()?; let mut rows = stmt.query(params)?; rows.get_expected_row().map_err(E::from).and_then(|r| f(&r)) @@ -1054,6 +1060,22 @@ mod test { } } + #[test] + #[cfg(feature = "extra_check")] + fn test_execute_multiple() { + let db = checked_memory_handle(); + let err = db + .execute( + "CREATE TABLE foo(x INTEGER); CREATE TABLE foo(x INTEGER)", + NO_PARAMS, + ) + .unwrap_err(); + match err { + Error::MultipleStatement => (), + _ => panic!("Unexpected error: {}", err), + } + } + #[test] fn test_prepare_column_names() { let db = checked_memory_handle(); diff --git a/src/raw_statement.rs b/src/raw_statement.rs index 9e1d74e..1a7a8e7 100644 --- a/src/raw_statement.rs +++ b/src/raw_statement.rs @@ -7,11 +7,11 @@ use std::ptr; // Private newtype for raw sqlite3_stmts that finalize themselves when dropped. #[derive(Debug)] -pub struct RawStatement(*mut ffi::sqlite3_stmt); +pub struct RawStatement(*mut ffi::sqlite3_stmt, bool); impl RawStatement { - pub fn new(stmt: *mut ffi::sqlite3_stmt) -> RawStatement { - RawStatement(stmt) + pub fn new(stmt: *mut ffi::sqlite3_stmt, tail: bool) -> RawStatement { + RawStatement(stmt, tail) } pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt { @@ -129,6 +129,10 @@ impl RawStatement { assert!(!self.0.is_null()); unsafe { ffi::sqlite3_stmt_status(self.0, status as i32, reset as i32) } } + + pub fn has_tail(&self) -> bool { + self.1 + } } impl Drop for RawStatement { diff --git a/src/statement.rs b/src/statement.rs index 683ba9a..7032316 100644 --- a/src/statement.rs +++ b/src/statement.rs @@ -528,7 +528,7 @@ impl Statement<'_> { } fn finalize_(&mut self) -> Result<()> { - let mut stmt = RawStatement::new(ptr::null_mut()); + let mut stmt = RawStatement::new(ptr::null_mut(), false); mem::swap(&mut stmt, &mut self.stmt); self.conn.decode_result(stmt.finalize()) } @@ -598,11 +598,26 @@ impl Statement<'_> { pub fn reset_status(&self, status: StatementStatus) -> i32 { self.stmt.get_status(status, true) } + + #[cfg(feature = "extra_check")] + pub(crate) fn check_no_tail(&self) -> Result<()> { + if self.stmt.has_tail() { + Err(Error::MultipleStatement) + } else { + Ok(()) + } + } + + #[cfg(not(feature = "extra_check"))] + #[inline] + pub(crate) fn check_no_tail(&self) -> Result<()> { + Ok(()) + } } impl Into for Statement<'_> { fn into(mut self) -> RawStatement { - let mut stmt = RawStatement::new(ptr::null_mut()); + let mut stmt = RawStatement::new(ptr::null_mut(), false); mem::swap(&mut stmt, &mut self.stmt); stmt }