diff --git a/.travis.yml b/.travis.yml index 00e02c0..4f51f65 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,7 +28,7 @@ script: - cargo build --features sqlcipher - cargo build --features "bundled sqlcipher" - cargo test - - cargo test --features "backup blob" + - cargo test --features "backup blob extra_check" - cargo test --features "collation functions" - cargo test --features "hooks limits" - cargo test --features load_extension diff --git a/Cargo.toml b/Cargo.toml index cd4a96d..a4fcb13 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,12 +48,13 @@ csvtab = ["csv", "vtab"] # pointer passing interfaces: 3.20.0 array = ["vtab"] # session extension: 3.13.0 -session = ["libsqlite3-sys/session", "hooks"] +#session = ["libsqlite3-sys/session", "hooks"] # window functions: 3.25.0 window = ["functions"] # 3.9.0 series = ["vtab"] - +# check for invalid query. +extra_check = [] [dependencies] time = "0.1.0" diff --git a/src/collation.rs b/src/collation.rs index 00c1cd1..1ca2b33 100644 --- a/src/collation.rs +++ b/src/collation.rs @@ -125,7 +125,9 @@ impl InnerConnection { str::from_utf8_unchecked(c_slice) }; callback(&conn, collation_name) - }).is_err() { + }) + .is_err() + { return; // FIXME How ? } } diff --git a/src/error.rs b/src/error.rs index 8156a81..ef22f61 100644 --- a/src/error.rs +++ b/src/error.rs @@ -102,7 +102,8 @@ pub enum Error { GetAuxWrongType, /// Error when the SQL contains multiple statements. - MultipleStatement,} + MultipleStatement, +} impl PartialEq for Error { fn eq(&self, other: &Error) -> bool { diff --git a/src/inner_connection.rs b/src/inner_connection.rs index 1e47a1b..75b1242 100644 --- a/src/inner_connection.rs +++ b/src/inner_connection.rs @@ -1,6 +1,6 @@ use std::ffi::CString; use std::mem::MaybeUninit; -use std::os::raw::c_int; +use std::os::raw::{c_char, c_int}; #[cfg(feature = "load_extension")] use std::path::Path; use std::ptr; @@ -178,8 +178,6 @@ impl InnerConnection { #[cfg(feature = "load_extension")] pub fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> Result<()> { - use std::os::raw::c_char; - let dylib_str = super::path_to_cstring(dylib_path)?; unsafe { let mut errmsg = MaybeUninit::uninit(); @@ -217,7 +215,7 @@ impl InnerConnection { pub fn prepare<'a>(&mut self, conn: &'a Connection, sql: &str) -> Result> { let mut c_stmt = MaybeUninit::uninit(); let (c_sql, len, _) = str_for_sqlite(sql.as_bytes())?; - let mut c_tail = ptr::null(); + let mut c_tail = MaybeUninit::uninit(); let r = unsafe { if cfg!(feature = "unlock_notify") { let mut rc; @@ -227,7 +225,7 @@ impl InnerConnection { c_sql, len, c_stmt.as_mut_ptr(), - &mut c_tail, + c_tail.as_mut_ptr(), ); if !unlock_notify::is_locked(self.db, rc) { break; @@ -239,16 +237,22 @@ impl InnerConnection { } rc } else { - ffi::sqlite3_prepare_v2(self.db(), c_sql, len, c_stmt.as_mut_ptr(), &mut c_tail) + ffi::sqlite3_prepare_v2( + self.db(), + c_sql, + len, + c_stmt.as_mut_ptr(), + c_tail.as_mut_ptr(), + ) } }; + self.decode_result(r)?; + let c_stmt: *mut ffi::sqlite3_stmt = unsafe { c_stmt.assume_init() }; - if !c_tail.is_null() && unsafe { *c_tail == 0 } { - // '\0' when there is no ';' at the end - c_tail = ptr::null(); // TODO ignore spaces, comments, ... at the end - } - self.decode_result(r) - .map(|_| Statement::new(conn, RawStatement::new(c_stmt, c_tail))) + let c_tail: *const c_char = unsafe { c_tail.assume_init() }; + // TODO ignore spaces, comments, ... at the end + let tail = !c_tail.is_null() && unsafe { c_tail != c_sql.offset(len as isize) }; + Ok(Statement::new(conn, RawStatement::new(c_stmt, tail))) } pub fn changes(&mut self) -> usize { diff --git a/src/lib.rs b/src/lib.rs index 0156241..056dfd9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -896,7 +896,8 @@ mod test { ) .expect("create temp db"); - let mut db1 = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_WRITE).unwrap(); + let mut db1 = + Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_WRITE).unwrap(); let mut db2 = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_ONLY).unwrap(); db1.busy_timeout(Duration::from_millis(0)).unwrap(); @@ -1066,7 +1067,8 @@ mod test { .execute( "CREATE TABLE foo(x INTEGER); CREATE TABLE foo(x INTEGER)", NO_PARAMS, - ).unwrap_err(); + ) + .unwrap_err(); match err { Error::MultipleStatement => (), _ => panic!("Unexpected error: {}", err), diff --git a/src/raw_statement.rs b/src/raw_statement.rs index 2ee5f61..1a7a8e7 100644 --- a/src/raw_statement.rs +++ b/src/raw_statement.rs @@ -2,15 +2,15 @@ use super::ffi; use super::unlock_notify; use super::StatementStatus; use std::ffi::CStr; -use std::os::raw::{c_char, c_int}; +use std::os::raw::c_int; use std::ptr; // Private newtype for raw sqlite3_stmts that finalize themselves when dropped. #[derive(Debug)] -pub struct RawStatement(*mut ffi::sqlite3_stmt, *const c_char); +pub struct RawStatement(*mut ffi::sqlite3_stmt, bool); impl RawStatement { - pub fn new(stmt: *mut ffi::sqlite3_stmt, tail: *const c_char) -> RawStatement { + pub fn new(stmt: *mut ffi::sqlite3_stmt, tail: bool) -> RawStatement { RawStatement(stmt, tail) } @@ -131,7 +131,7 @@ impl RawStatement { } pub fn has_tail(&self) -> bool { - !self.1.is_null() + self.1 } } diff --git a/src/statement.rs b/src/statement.rs index 8b2fccc..7032316 100644 --- a/src/statement.rs +++ b/src/statement.rs @@ -511,6 +511,7 @@ impl Statement<'_> { } fn execute_with_bound_parameters(&mut self) -> Result { + self.check_update()?; let r = self.stmt.step(); self.stmt.reset(); match r { @@ -527,7 +528,7 @@ impl Statement<'_> { } fn finalize_(&mut self) -> Result<()> { - let mut stmt = RawStatement::new(ptr::null_mut(), ptr::null()); + let mut stmt = RawStatement::new(ptr::null_mut(), false); mem::swap(&mut stmt, &mut self.stmt); self.conn.decode_result(stmt.finalize()) } @@ -547,6 +548,30 @@ impl Statement<'_> { Ok(()) } + #[cfg(all(feature = "bundled", feature = "extra_check"))] + #[inline] + fn check_update(&self) -> Result<()> { + if self.column_count() > 0 || self.stmt.readonly() { + return Err(Error::ExecuteReturnedResults); + } + Ok(()) + } + + #[cfg(all(not(feature = "bundled"), feature = "extra_check"))] + #[inline] + fn check_update(&self) -> Result<()> { + if self.column_count() > 0 { + return Err(Error::ExecuteReturnedResults); + } + Ok(()) + } + + #[cfg(not(feature = "extra_check"))] + #[inline] + fn check_update(&self) -> Result<()> { + Ok(()) + } + /// Returns a string containing the SQL text of prepared statement with /// bound parameters expanded. #[cfg(feature = "bundled")] @@ -574,6 +599,7 @@ impl Statement<'_> { self.stmt.get_status(status, true) } + #[cfg(feature = "extra_check")] pub(crate) fn check_no_tail(&self) -> Result<()> { if self.stmt.has_tail() { Err(Error::MultipleStatement) @@ -581,11 +607,17 @@ impl Statement<'_> { Ok(()) } } + + #[cfg(not(feature = "extra_check"))] + #[inline] + pub(crate) fn check_no_tail(&self) -> Result<()> { + Ok(()) + } } impl Into for Statement<'_> { fn into(mut self) -> RawStatement { - let mut stmt = RawStatement::new(ptr::null_mut(), ptr::null()); + let mut stmt = RawStatement::new(ptr::null_mut(), false); mem::swap(&mut stmt, &mut self.stmt); stmt }