mirror of
https://github.com/isar/rusqlite.git
synced 2025-09-16 12:42:18 +08:00
Introduce RowIndex trait (like in rust-postgres)
This commit is contained in:
71
src/lib.rs
71
src/lib.rs
@@ -728,6 +728,23 @@ impl<'conn> Statement<'conn> {
|
||||
cols
|
||||
}
|
||||
|
||||
/// Returns the column index in the result set for a given column name.
|
||||
/// If there is no AS clause then the name of the column is unspecified and may change from one release of SQLite to the next.
|
||||
///
|
||||
/// # Failure
|
||||
/// Will return an `Error::InvalidColumnName` when there is no column with the specified `name`.
|
||||
pub fn column_index(&self, name: &str) -> Result<i32> {
|
||||
let bytes = name.as_bytes();
|
||||
let n = self.column_count;
|
||||
for i in 0..n {
|
||||
let slice = unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.stmt, i)) };
|
||||
if bytes == slice.to_bytes() {
|
||||
return Ok(i);
|
||||
}
|
||||
}
|
||||
Err(Error::InvalidColumnName(String::from(name)))
|
||||
}
|
||||
|
||||
/// Execute the prepared statement.
|
||||
///
|
||||
/// On success, returns the number of rows that were changed or inserted or deleted (via
|
||||
@@ -1078,9 +1095,11 @@ impl<'stmt> Row<'stmt> {
|
||||
///
|
||||
/// ## Failure
|
||||
///
|
||||
/// Panics if the underlying SQLite column type is not a valid type as a source for `T`.
|
||||
///
|
||||
/// Panics if `idx` is outside the range of columns in the returned query or if this row
|
||||
/// is stale.
|
||||
pub fn get<T: FromSql>(&self, idx: c_int) -> T {
|
||||
pub fn get<I: RowIndex, T: FromSql>(&self, idx: I) -> T {
|
||||
self.get_checked(idx).unwrap()
|
||||
}
|
||||
|
||||
@@ -1088,19 +1107,22 @@ impl<'stmt> Row<'stmt> {
|
||||
///
|
||||
/// ## Failure
|
||||
///
|
||||
/// Returns a `SQLITE_MISMATCH`-coded `Error` if the underlying SQLite column
|
||||
/// Returns an `Error::InvalidColumnType` if the underlying SQLite column
|
||||
/// type is not a valid type as a source for `T`.
|
||||
///
|
||||
/// Returns a `SQLITE_MISUSE`-coded `Error` if `idx` is outside the valid column range
|
||||
/// for this row or if this row is stale.
|
||||
pub fn get_checked<T: FromSql>(&self, idx: c_int) -> Result<T> {
|
||||
/// Returns an `Error::InvalidColumnIndex` if `idx` is outside the valid column range
|
||||
/// for this row.
|
||||
///
|
||||
/// Returns an `Error::InvalidColumnName` if `idx` is not a valid column name
|
||||
/// for this row.
|
||||
///
|
||||
/// Returns an `Error::GetFromStaleRow` if this row is stale.
|
||||
pub fn get_checked<I: RowIndex, T: FromSql>(&self, idx: I) -> Result<T> {
|
||||
if self.row_idx != self.current_row.get() {
|
||||
return Err(Error::GetFromStaleRow);
|
||||
}
|
||||
unsafe {
|
||||
if idx < 0 || idx >= self.stmt.column_count {
|
||||
return Err(Error::InvalidColumnIndex(idx));
|
||||
}
|
||||
let idx = try!(idx.idx(self.stmt));
|
||||
|
||||
if T::column_has_valid_sqlite_type(self.stmt.stmt, idx) {
|
||||
FromSql::column_result(self.stmt.stmt, idx)
|
||||
@@ -1111,6 +1133,31 @@ impl<'stmt> Row<'stmt> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A trait implemented by types that can index into columns of a row.
|
||||
pub trait RowIndex {
|
||||
/// Returns the index of the appropriate column, or `None` if no such
|
||||
/// column exists.
|
||||
fn idx(&self, stmt: &Statement) -> Result<i32>;
|
||||
}
|
||||
|
||||
impl RowIndex for i32 {
|
||||
#[inline]
|
||||
fn idx(&self, stmt: &Statement) -> Result<i32> {
|
||||
if *self < 0 || *self >= stmt.column_count {
|
||||
Err(Error::InvalidColumnIndex(*self))
|
||||
} else {
|
||||
Ok(*self)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> RowIndex for &'a str {
|
||||
#[inline]
|
||||
fn idx(&self, stmt: &Statement) -> Result<i32> {
|
||||
stmt.column_index(*self)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
extern crate libsqlite3_sys as ffi;
|
||||
@@ -1149,7 +1196,7 @@ mod test {
|
||||
|
||||
let path_string = path.to_str().unwrap();
|
||||
let db = Connection::open(&path_string).unwrap();
|
||||
let the_answer = db.query_row("SELECT x FROM foo", &[], |r| r.get::<i64>(0));
|
||||
let the_answer: Result<i64> = db.query_row("SELECT x FROM foo", &[], |r| r.get(0));
|
||||
|
||||
assert_eq!(42i64, the_answer.unwrap());
|
||||
}
|
||||
@@ -1306,10 +1353,10 @@ mod test {
|
||||
db.execute_batch(sql).unwrap();
|
||||
|
||||
assert_eq!(10i64,
|
||||
db.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get::<i64>(0))
|
||||
db.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0))
|
||||
.unwrap());
|
||||
|
||||
let result = db.query_row("SELECT x FROM foo WHERE x > 5", &[], |r| r.get::<i64>(0));
|
||||
let result: Result<i64> = db.query_row("SELECT x FROM foo WHERE x > 5", &[], |r| r.get(0));
|
||||
match result.unwrap_err() {
|
||||
Error::QueryReturnedNoRows => (),
|
||||
err => panic!("Unexpected error {}", err),
|
||||
@@ -1343,7 +1390,7 @@ mod test {
|
||||
|
||||
assert_eq!(2i32, second.get(0));
|
||||
|
||||
match first.get_checked::<i32>(0).unwrap_err() {
|
||||
match first.get_checked::<i32,i32>(0).unwrap_err() {
|
||||
Error::GetFromStaleRow => (),
|
||||
err => panic!("Unexpected error {}", err),
|
||||
}
|
||||
|
Reference in New Issue
Block a user