From 12f26e78b360d79cb7761e1f9c24752dcae6b7b8 Mon Sep 17 00:00:00 2001 From: gwenn Date: Sat, 2 Jan 2016 12:13:37 +0100 Subject: [PATCH] Introduce RowIndex trait (like in rust-postgres) --- src/backup.rs | 12 ++++---- src/error.rs | 7 +++++ src/functions.rs | 28 +++++++++--------- src/lib.rs | 71 ++++++++++++++++++++++++++++++++++++-------- src/named_params.rs | 8 ++--- src/types.rs | 72 ++++++++++++++++++++++----------------------- 6 files changed, 126 insertions(+), 72 deletions(-) diff --git a/src/backup.rs b/src/backup.rs index 047f4fd..e55c8ee 100644 --- a/src/backup.rs +++ b/src/backup.rs @@ -315,7 +315,7 @@ mod test { backup.step(-1).unwrap(); } - let the_answer = dst.query_row("SELECT x FROM foo", &[], |r| r.get::(0)).unwrap(); + let the_answer: i64 = dst.query_row("SELECT x FROM foo", &[], |r| r.get(0)).unwrap(); assert_eq!(42, the_answer); src.execute_batch("INSERT INTO foo VALUES(43)").unwrap(); @@ -325,7 +325,7 @@ mod test { backup.run_to_completion(5, Duration::from_millis(250), None).unwrap(); } - let the_answer = dst.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get::(0)).unwrap(); + let the_answer: i64 = dst.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0)).unwrap(); assert_eq!(42 + 43, the_answer); } @@ -350,7 +350,7 @@ mod test { backup.step(-1).unwrap(); } - let the_answer = dst.query_row("SELECT x FROM foo", &[], |r| r.get::(0)).unwrap(); + let the_answer: i64 = dst.query_row("SELECT x FROM foo", &[], |r| r.get(0)).unwrap(); assert_eq!(42, the_answer); src.execute_batch("INSERT INTO foo VALUES(43)").unwrap(); @@ -364,7 +364,7 @@ mod test { backup.run_to_completion(5, Duration::from_millis(250), None).unwrap(); } - let the_answer = dst.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get::(0)).unwrap(); + let the_answer: i64 = dst.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0)).unwrap(); assert_eq!(42 + 43, the_answer); } @@ -390,7 +390,7 @@ mod test { backup.step(-1).unwrap(); } - let the_answer = dst.query_row("SELECT x FROM foo", &[], |r| r.get::(0)).unwrap(); + let the_answer: i64 = dst.query_row("SELECT x FROM foo", &[], |r| r.get(0)).unwrap(); assert_eq!(42, the_answer); src.execute_batch("INSERT INTO foo VALUES(43)").unwrap(); @@ -404,7 +404,7 @@ mod test { backup.run_to_completion(5, Duration::from_millis(250), None).unwrap(); } - let the_answer = dst.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get::(0)).unwrap(); + let the_answer: i64 = dst.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0)).unwrap(); assert_eq!(42 + 43, the_answer); } } diff --git a/src/error.rs b/src/error.rs index 9670800..bf128eb 100644 --- a/src/error.rs +++ b/src/error.rs @@ -48,6 +48,10 @@ pub enum Error { /// for the statement. InvalidColumnIndex(c_int), + /// Error when the value of a named column is requested, but no column matches the name + /// for the statement. + InvalidColumnName(String), + /// Error when the value of a particular column is requested, but the type of the result in /// that column cannot be converted to the requested Rust type. InvalidColumnType, @@ -91,6 +95,7 @@ impl fmt::Display for Error { &Error::QueryReturnedNoRows => write!(f, "Query returned no rows"), &Error::GetFromStaleRow => write!(f, "Attempted to get a value from a stale row"), &Error::InvalidColumnIndex(i) => write!(f, "Invalid column index: {}", i), + &Error::InvalidColumnName(ref name) => write!(f, "Invalid column name: {}", name), &Error::InvalidColumnType => write!(f, "Invalid column type"), #[cfg(feature = "functions")] @@ -116,6 +121,7 @@ impl error::Error for Error { &Error::QueryReturnedNoRows => "query returned no rows", &Error::GetFromStaleRow => "attempted to get a value from a stale row", &Error::InvalidColumnIndex(_) => "invalid column index", + &Error::InvalidColumnName(_) => "invalid column name", &Error::InvalidColumnType => "invalid column type", #[cfg(feature = "functions")] @@ -138,6 +144,7 @@ impl error::Error for Error { &Error::QueryReturnedNoRows => None, &Error::GetFromStaleRow => None, &Error::InvalidColumnIndex(_) => None, + &Error::InvalidColumnName(_) => None, &Error::InvalidColumnType => None, #[cfg(feature = "functions")] diff --git a/src/functions.rs b/src/functions.rs index 1def91c..d9fecb0 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -43,8 +43,8 @@ //! let db = Connection::open_in_memory().unwrap(); //! add_regexp_function(&db).unwrap(); //! -//! let is_match = db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", &[], -//! |row| row.get::(0)).unwrap(); +//! let is_match: bool = db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", &[], +//! |row| row.get(0)).unwrap(); //! //! assert!(is_match); //! } @@ -354,7 +354,7 @@ impl Connection { /// Ok(value / 2f64) /// })); /// - /// let six_halved = try!(db.query_row("SELECT halve(6)", &[], |r| r.get::(0))); + /// let six_halved: f64 = try!(db.query_row("SELECT halve(6)", &[], |r| r.get(0))); /// assert_eq!(six_halved, 3f64); /// Ok(()) /// } @@ -485,7 +485,7 @@ mod test { fn test_function_half() { let db = Connection::open_in_memory().unwrap(); db.create_scalar_function("half", 1, true, half).unwrap(); - let result = db.query_row("SELECT half(6)", &[], |r| r.get::(0)); + let result: Result = db.query_row("SELECT half(6)", &[], |r| r.get(0)); assert_eq!(3f64, result.unwrap()); } @@ -494,11 +494,11 @@ mod test { fn test_remove_function() { let db = Connection::open_in_memory().unwrap(); db.create_scalar_function("half", 1, true, half).unwrap(); - let result = db.query_row("SELECT half(6)", &[], |r| r.get::(0)); + let result: Result = db.query_row("SELECT half(6)", &[], |r| r.get(0)); assert_eq!(3f64, result.unwrap()); db.remove_function("half", 1).unwrap(); - let result = db.query_row("SELECT half(6)", &[], |r| r.get::(0)); + let result: Result = db.query_row("SELECT half(6)", &[], |r| r.get(0)); assert!(result.is_err()); } @@ -546,15 +546,15 @@ mod test { END;").unwrap(); db.create_scalar_function("regexp", 2, true, regexp_with_auxilliary).unwrap(); - let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", + let result: Result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", &[], - |r| r.get::(0)); + |r| r.get(0)); assert_eq!(true, result.unwrap()); - let result = db.query_row("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1", + let result: Result = db.query_row("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1", &[], - |r| r.get::(0)); + |r| r.get(0)); assert_eq!(2, result.unwrap()); } @@ -596,15 +596,15 @@ mod test { Ok(regex.is_match(&text)) }).unwrap(); - let result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", + let result: Result = db.query_row("SELECT regexp('l.s[aeiouy]', 'lisa')", &[], - |r| r.get::(0)); + |r| r.get(0)); assert_eq!(true, result.unwrap()); - let result = db.query_row("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1", + let result: Result = db.query_row("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1", &[], - |r| r.get::(0)); + |r| r.get(0)); assert_eq!(2, result.unwrap()); } diff --git a/src/lib.rs b/src/lib.rs index 0b77690..4cde86c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -728,6 +728,23 @@ impl<'conn> Statement<'conn> { cols } + /// Returns the column index in the result set for a given column name. + /// If there is no AS clause then the name of the column is unspecified and may change from one release of SQLite to the next. + /// + /// # Failure + /// Will return an `Error::InvalidColumnName` when there is no column with the specified `name`. + pub fn column_index(&self, name: &str) -> Result { + let bytes = name.as_bytes(); + let n = self.column_count; + for i in 0..n { + let slice = unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.stmt, i)) }; + if bytes == slice.to_bytes() { + return Ok(i); + } + } + Err(Error::InvalidColumnName(String::from(name))) + } + /// Execute the prepared statement. /// /// On success, returns the number of rows that were changed or inserted or deleted (via @@ -1078,9 +1095,11 @@ impl<'stmt> Row<'stmt> { /// /// ## Failure /// + /// Panics if the underlying SQLite column type is not a valid type as a source for `T`. + /// /// Panics if `idx` is outside the range of columns in the returned query or if this row /// is stale. - pub fn get(&self, idx: c_int) -> T { + pub fn get(&self, idx: I) -> T { self.get_checked(idx).unwrap() } @@ -1088,19 +1107,22 @@ impl<'stmt> Row<'stmt> { /// /// ## Failure /// - /// Returns a `SQLITE_MISMATCH`-coded `Error` if the underlying SQLite column + /// Returns an `Error::InvalidColumnType` if the underlying SQLite column /// type is not a valid type as a source for `T`. /// - /// Returns a `SQLITE_MISUSE`-coded `Error` if `idx` is outside the valid column range - /// for this row or if this row is stale. - pub fn get_checked(&self, idx: c_int) -> Result { + /// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid column range + /// for this row. + /// + /// Returns an `Error::InvalidColumnName` if `idx` is not a valid column name + /// for this row. + /// + /// Returns an `Error::GetFromStaleRow` if this row is stale. + pub fn get_checked(&self, idx: I) -> Result { if self.row_idx != self.current_row.get() { return Err(Error::GetFromStaleRow); } unsafe { - if idx < 0 || idx >= self.stmt.column_count { - return Err(Error::InvalidColumnIndex(idx)); - } + let idx = try!(idx.idx(self.stmt)); if T::column_has_valid_sqlite_type(self.stmt.stmt, idx) { FromSql::column_result(self.stmt.stmt, idx) @@ -1111,6 +1133,31 @@ impl<'stmt> Row<'stmt> { } } +/// A trait implemented by types that can index into columns of a row. +pub trait RowIndex { + /// Returns the index of the appropriate column, or `None` if no such + /// column exists. + fn idx(&self, stmt: &Statement) -> Result; +} + +impl RowIndex for i32 { + #[inline] + fn idx(&self, stmt: &Statement) -> Result { + if *self < 0 || *self >= stmt.column_count { + Err(Error::InvalidColumnIndex(*self)) + } else { + Ok(*self) + } + } +} + +impl<'a> RowIndex for &'a str { + #[inline] + fn idx(&self, stmt: &Statement) -> Result { + stmt.column_index(*self) + } +} + #[cfg(test)] mod test { extern crate libsqlite3_sys as ffi; @@ -1149,7 +1196,7 @@ mod test { let path_string = path.to_str().unwrap(); let db = Connection::open(&path_string).unwrap(); - let the_answer = db.query_row("SELECT x FROM foo", &[], |r| r.get::(0)); + let the_answer: Result = db.query_row("SELECT x FROM foo", &[], |r| r.get(0)); assert_eq!(42i64, the_answer.unwrap()); } @@ -1306,10 +1353,10 @@ mod test { db.execute_batch(sql).unwrap(); assert_eq!(10i64, - db.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get::(0)) + db.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0)) .unwrap()); - let result = db.query_row("SELECT x FROM foo WHERE x > 5", &[], |r| r.get::(0)); + let result: Result = db.query_row("SELECT x FROM foo WHERE x > 5", &[], |r| r.get(0)); match result.unwrap_err() { Error::QueryReturnedNoRows => (), err => panic!("Unexpected error {}", err), @@ -1343,7 +1390,7 @@ mod test { assert_eq!(2i32, second.get(0)); - match first.get_checked::(0).unwrap_err() { + match first.get_checked::(0).unwrap_err() { Error::GetFromStaleRow => (), err => panic!("Unexpected error {}", err), } diff --git a/src/named_params.rs b/src/named_params.rs index dc880ca..4f55dd8 100644 --- a/src/named_params.rs +++ b/src/named_params.rs @@ -198,8 +198,8 @@ mod test { let mut stmt = db.prepare("INSERT INTO test (x, y) VALUES (:x, :y)").unwrap(); stmt.execute_named(&[(":x", &"one")]).unwrap(); - let result = db.query_row("SELECT y FROM test WHERE x = 'one'", &[], - |row| row.get::>(0)).unwrap(); + let result: Option = db.query_row("SELECT y FROM test WHERE x = 'one'", &[], + |row| row.get(0)).unwrap(); assert!(result.is_none()); } @@ -213,8 +213,8 @@ mod test { stmt.execute_named(&[(":x", &"one")]).unwrap(); stmt.execute_named(&[(":y", &"two")]).unwrap(); - let result = db.query_row("SELECT x FROM test WHERE y = 'two'", &[], - |row| row.get::(0)).unwrap(); + let result: String = db.query_row("SELECT x FROM test WHERE y = 'two'", &[], + |row| row.get(0)).unwrap(); assert_eq!(result, "one"); } } diff --git a/src/types.rs b/src/types.rs index c33afa8..e80f267 100644 --- a/src/types.rs +++ b/src/types.rs @@ -389,53 +389,53 @@ mod test { let row = rows.next().unwrap().unwrap(); // check the correct types come back as expected - assert_eq!(vec![1, 2], row.get_checked::>(0).unwrap()); - assert_eq!("text", row.get_checked::(1).unwrap()); - assert_eq!(1, row.get_checked::(2).unwrap()); - assert_eq!(1.5, row.get_checked::(3).unwrap()); - assert!(row.get_checked::>(4).unwrap().is_none()); - assert!(row.get_checked::>(4).unwrap().is_none()); - assert!(row.get_checked::>(4).unwrap().is_none()); + assert_eq!(vec![1, 2], row.get_checked::>(0).unwrap()); + assert_eq!("text", row.get_checked::(1).unwrap()); + assert_eq!(1, row.get_checked::(2).unwrap()); + assert_eq!(1.5, row.get_checked::(3).unwrap()); + assert!(row.get_checked::>(4).unwrap().is_none()); + assert!(row.get_checked::>(4).unwrap().is_none()); + assert!(row.get_checked::>(4).unwrap().is_none()); // check some invalid types // 0 is actually a blob (Vec) - assert!(is_invalid_column_type(row.get_checked::(0).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(0).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(0).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(0).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(0).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(0).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::>(0).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(0).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(0).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(0).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(0).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(0).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(0).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::>(0).err().unwrap())); // 1 is actually a text (String) - assert!(is_invalid_column_type(row.get_checked::(1).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(1).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(1).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::>(1).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::>(1).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(1).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(1).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(1).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::>(1).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::>(1).err().unwrap())); // 2 is actually an integer - assert!(is_invalid_column_type(row.get_checked::(2).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(2).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::>(2).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(2).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::>(2).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(2).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(2).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::>(2).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(2).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::>(2).err().unwrap())); // 3 is actually a float (c_double) - assert!(is_invalid_column_type(row.get_checked::(3).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(3).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(3).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::>(3).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(3).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::>(3).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(3).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(3).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(3).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::>(3).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(3).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::>(3).err().unwrap())); // 4 is actually NULL - assert!(is_invalid_column_type(row.get_checked::(4).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(4).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(4).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(4).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::>(4).err().unwrap())); - assert!(is_invalid_column_type(row.get_checked::(4).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(4).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(4).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(4).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(4).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::>(4).err().unwrap())); + assert!(is_invalid_column_type(row.get_checked::(4).err().unwrap())); } }