Introduce RowIndex trait (like in rust-postgres)

This commit is contained in:
gwenn
2016-01-02 12:13:37 +01:00
parent 38cf8d597b
commit 12f26e78b3
6 changed files with 126 additions and 72 deletions

View File

@@ -728,6 +728,23 @@ impl<'conn> Statement<'conn> {
cols
}
/// Returns the column index in the result set for a given column name.
/// If there is no AS clause then the name of the column is unspecified and may change from one release of SQLite to the next.
///
/// # Failure
/// Will return an `Error::InvalidColumnName` when there is no column with the specified `name`.
pub fn column_index(&self, name: &str) -> Result<i32> {
let bytes = name.as_bytes();
let n = self.column_count;
for i in 0..n {
let slice = unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.stmt, i)) };
if bytes == slice.to_bytes() {
return Ok(i);
}
}
Err(Error::InvalidColumnName(String::from(name)))
}
/// Execute the prepared statement.
///
/// On success, returns the number of rows that were changed or inserted or deleted (via
@@ -1078,9 +1095,11 @@ impl<'stmt> Row<'stmt> {
///
/// ## Failure
///
/// Panics if the underlying SQLite column type is not a valid type as a source for `T`.
///
/// Panics if `idx` is outside the range of columns in the returned query or if this row
/// is stale.
pub fn get<T: FromSql>(&self, idx: c_int) -> T {
pub fn get<I: RowIndex, T: FromSql>(&self, idx: I) -> T {
self.get_checked(idx).unwrap()
}
@@ -1088,19 +1107,22 @@ impl<'stmt> Row<'stmt> {
///
/// ## Failure
///
/// Returns a `SQLITE_MISMATCH`-coded `Error` if the underlying SQLite column
/// Returns an `Error::InvalidColumnType` if the underlying SQLite column
/// type is not a valid type as a source for `T`.
///
/// Returns a `SQLITE_MISUSE`-coded `Error` if `idx` is outside the valid column range
/// for this row or if this row is stale.
pub fn get_checked<T: FromSql>(&self, idx: c_int) -> Result<T> {
/// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid column range
/// for this row.
///
/// Returns an `Error::InvalidColumnName` if `idx` is not a valid column name
/// for this row.
///
/// Returns an `Error::GetFromStaleRow` if this row is stale.
pub fn get_checked<I: RowIndex, T: FromSql>(&self, idx: I) -> Result<T> {
if self.row_idx != self.current_row.get() {
return Err(Error::GetFromStaleRow);
}
unsafe {
if idx < 0 || idx >= self.stmt.column_count {
return Err(Error::InvalidColumnIndex(idx));
}
let idx = try!(idx.idx(self.stmt));
if T::column_has_valid_sqlite_type(self.stmt.stmt, idx) {
FromSql::column_result(self.stmt.stmt, idx)
@@ -1111,6 +1133,31 @@ impl<'stmt> Row<'stmt> {
}
}
/// A trait implemented by types that can index into columns of a row.
pub trait RowIndex {
/// Returns the index of the appropriate column, or `None` if no such
/// column exists.
fn idx(&self, stmt: &Statement) -> Result<i32>;
}
impl RowIndex for i32 {
#[inline]
fn idx(&self, stmt: &Statement) -> Result<i32> {
if *self < 0 || *self >= stmt.column_count {
Err(Error::InvalidColumnIndex(*self))
} else {
Ok(*self)
}
}
}
impl<'a> RowIndex for &'a str {
#[inline]
fn idx(&self, stmt: &Statement) -> Result<i32> {
stmt.column_index(*self)
}
}
#[cfg(test)]
mod test {
extern crate libsqlite3_sys as ffi;
@@ -1149,7 +1196,7 @@ mod test {
let path_string = path.to_str().unwrap();
let db = Connection::open(&path_string).unwrap();
let the_answer = db.query_row("SELECT x FROM foo", &[], |r| r.get::<i64>(0));
let the_answer: Result<i64> = db.query_row("SELECT x FROM foo", &[], |r| r.get(0));
assert_eq!(42i64, the_answer.unwrap());
}
@@ -1306,10 +1353,10 @@ mod test {
db.execute_batch(sql).unwrap();
assert_eq!(10i64,
db.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get::<i64>(0))
db.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0))
.unwrap());
let result = db.query_row("SELECT x FROM foo WHERE x > 5", &[], |r| r.get::<i64>(0));
let result: Result<i64> = db.query_row("SELECT x FROM foo WHERE x > 5", &[], |r| r.get(0));
match result.unwrap_err() {
Error::QueryReturnedNoRows => (),
err => panic!("Unexpected error {}", err),
@@ -1343,7 +1390,7 @@ mod test {
assert_eq!(2i32, second.get(0));
match first.get_checked::<i32>(0).unwrap_err() {
match first.get_checked::<i32,i32>(0).unwrap_err() {
Error::GetFromStaleRow => (),
err => panic!("Unexpected error {}", err),
}