From 37cfcf470bbb20600fe4872209bc7e19c5743829 Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Mon, 4 May 2015 21:47:20 -0400 Subject: [PATCH] Add SqliteRow::get_checked, which performs basic SQLite column type checking. --- src/lib.rs | 24 ++++++++++++ src/types.rs | 107 ++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 125 insertions(+), 6 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 13007cc..62aae8d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -797,6 +797,30 @@ impl<'stmt> SqliteRow<'stmt> { self.get_opt(idx).unwrap() } + /// Get the value of a particular column of the result row. + /// + /// ## Failure + /// + /// Returns a `SQLITE_MISMATCH`-coded `SqliteError` if the underlying SQLite column + /// type is not a valid type as a source for `T`. + /// + /// Panics if `idx` is outside the range of columns in the returned query or if this row + /// is stale. + pub fn get_checked(&self, idx: c_int) -> SqliteResult { + let valid_column_type = unsafe { + T::column_has_valid_sqlite_type(self.stmt.stmt, idx) + }; + + if valid_column_type { + Ok(self.get(idx)) + } else { + Err(SqliteError{ + code: ffi::SQLITE_MISMATCH, + message: "Invalid column type".to_string(), + }) + } + } + /// Attempt to get the value of a particular column of the result row. /// /// ## Failure diff --git a/src/types.rs b/src/types.rs index f902a59..29eb72b 100644 --- a/src/types.rs +++ b/src/types.rs @@ -62,6 +62,9 @@ use super::ffi; use super::{SqliteResult, SqliteError, str_to_cstring}; pub use ffi::sqlite3_stmt as sqlite3_stmt; +pub use ffi::sqlite3_column_type as sqlite3_column_type; + +pub use ffi::{SQLITE_INTEGER, SQLITE_FLOAT, SQLITE_TEXT, SQLITE_BLOB, SQLITE_NULL}; const SQLITE_DATETIME_FMT: &'static str = "%Y-%m-%d %H:%M:%S"; @@ -73,6 +76,14 @@ pub trait ToSql { /// A trait for types that can be created from a SQLite value. pub trait FromSql { unsafe fn column_result(stmt: *mut sqlite3_stmt, col: c_int) -> SqliteResult; + + /// FromSql types can implement this method and use sqlite3_column_type to check that + /// the type reported by SQLite matches a type suitable for Self. This method is used + /// by `SqliteRow::get_checked` to confirm that the column contains a valid type before + /// attempting to retrieve the value. + unsafe fn column_has_valid_sqlite_type(_: *mut sqlite3_stmt, _: c_int) -> bool { + true + } } macro_rules! raw_to_impl( @@ -161,18 +172,22 @@ impl ToSql for Null { } macro_rules! raw_from_impl( - ($t:ty, $f:ident) => ( + ($t:ty, $f:ident, $c:expr) => ( impl FromSql for $t { unsafe fn column_result(stmt: *mut sqlite3_stmt, col: c_int) -> SqliteResult<$t> { Ok(ffi::$f(stmt, col)) } + + unsafe fn column_has_valid_sqlite_type(stmt: *mut sqlite3_stmt, col: c_int) -> bool { + sqlite3_column_type(stmt, col) == $c + } } ) ); -raw_from_impl!(c_int, sqlite3_column_int); -raw_from_impl!(i64, sqlite3_column_int64); -raw_from_impl!(c_double, sqlite3_column_double); +raw_from_impl!(c_int, sqlite3_column_int, ffi::SQLITE_INTEGER); +raw_from_impl!(i64, sqlite3_column_int64, ffi::SQLITE_INTEGER); +raw_from_impl!(c_double, sqlite3_column_double, ffi::SQLITE_FLOAT); impl FromSql for String { unsafe fn column_result(stmt: *mut sqlite3_stmt, col: c_int) -> SqliteResult { @@ -187,6 +202,10 @@ impl FromSql for String { .map_err(|e| { SqliteError{code: 0, message: e.to_string()} }) } } + + unsafe fn column_has_valid_sqlite_type(stmt: *mut sqlite3_stmt, col: c_int) -> bool { + sqlite3_column_type(stmt, col) == ffi::SQLITE_TEXT + } } impl FromSql for Vec { @@ -202,6 +221,10 @@ impl FromSql for Vec { Ok(from_raw_parts(mem::transmute(c_blob), len).to_vec()) } + + unsafe fn column_has_valid_sqlite_type(stmt: *mut sqlite3_stmt, col: c_int) -> bool { + sqlite3_column_type(stmt, col) == ffi::SQLITE_BLOB + } } impl FromSql for time::Timespec { @@ -216,26 +239,37 @@ impl FromSql for time::Timespec { }) }) } + + unsafe fn column_has_valid_sqlite_type(stmt: *mut sqlite3_stmt, col: c_int) -> bool { + String::column_has_valid_sqlite_type(stmt, col) + } } impl FromSql for Option { unsafe fn column_result(stmt: *mut sqlite3_stmt, col: c_int) -> SqliteResult> { - if ffi::sqlite3_column_type(stmt, col) == ffi::SQLITE_NULL { + if sqlite3_column_type(stmt, col) == ffi::SQLITE_NULL { Ok(None) } else { FromSql::column_result(stmt, col).map(|t| Some(t)) } } + + unsafe fn column_has_valid_sqlite_type(stmt: *mut sqlite3_stmt, col: c_int) -> bool { + sqlite3_column_type(stmt, col) == ffi::SQLITE_NULL || + T::column_has_valid_sqlite_type(stmt, col) + } } #[cfg(test)] mod test { use SqliteConnection; + use ffi; use super::time; + use libc::{c_int, c_double}; fn checked_memory_handle() -> SqliteConnection { let db = SqliteConnection::open_in_memory().unwrap(); - db.execute_batch("CREATE TABLE foo (b BLOB, t TEXT)").unwrap(); + db.execute_batch("CREATE TABLE foo (b BLOB, t TEXT, i INTEGER, f FLOAT, n)").unwrap(); db } @@ -297,4 +331,65 @@ mod test { assert!(s2.is_none()); assert_eq!(b, b2); } + + #[test] + fn test_mismatched_types() { + let db = checked_memory_handle(); + + db.execute("INSERT INTO foo(b, t, i, f) VALUES (X'0102', 'text', 1, 1.5)", &[]).unwrap(); + + let mut stmt = db.prepare("SELECT b, t, i, f, n FROM foo").unwrap(); + let mut rows = stmt.query(&[]).unwrap(); + + let row = rows.next().unwrap().unwrap(); + + // check the correct types come back as expected + assert_eq!(vec![1,2], row.get_checked::>(0).unwrap()); + assert_eq!("text", row.get_checked::(1).unwrap()); + assert_eq!(1, row.get_checked::(2).unwrap()); + assert_eq!(1.5, row.get_checked::(3).unwrap()); + assert!(row.get_checked::>(4).unwrap().is_none()); + assert!(row.get_checked::>(4).unwrap().is_none()); + assert!(row.get_checked::>(4).unwrap().is_none()); + + // check some invalid types + + // 0 is actually a blob (Vec) + assert_eq!(row.get_checked::(0).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(0).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(0).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(0).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(0).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::>(0).err().unwrap().code, ffi::SQLITE_MISMATCH); + + // 1 is actually a text (String) + assert_eq!(row.get_checked::(1).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(1).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(1).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::>(1).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::>(1).err().unwrap().code, ffi::SQLITE_MISMATCH); + + // 2 is actually an integer + assert_eq!(row.get_checked::(2).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(2).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::>(2).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(2).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::>(2).err().unwrap().code, ffi::SQLITE_MISMATCH); + + // 3 is actually a float (c_double) + assert_eq!(row.get_checked::(3).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(3).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(3).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::>(3).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(3).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::>(3).err().unwrap().code, ffi::SQLITE_MISMATCH); + + // 4 is actually NULL + assert_eq!(row.get_checked::(4).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(4).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(4).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(4).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::>(4).err().unwrap().code, ffi::SQLITE_MISMATCH); + assert_eq!(row.get_checked::(4).err().unwrap().code, ffi::SQLITE_MISMATCH); + } }