Add test and fix for invalid cached column_count.

Issue raised in
https://github.com/jgallagher/rusqlite/pull/113#issuecomment-220122048.
This commit is contained in:
John Gallagher 2016-05-18 22:19:04 -05:00
parent 437a06fca3
commit 74b57ee47a
2 changed files with 42 additions and 11 deletions

View File

@ -113,9 +113,9 @@ impl StatementCache {
// Will return `Err` if no cached statement can be found and the underlying SQLite prepare // Will return `Err` if no cached statement can be found and the underlying SQLite prepare
// call fails. // call fails.
fn get<'conn>(&'conn self, fn get<'conn>(&'conn self,
conn: &'conn Connection, conn: &'conn Connection,
sql: &str) sql: &str)
-> Result<CachedStatement<'conn>> { -> Result<CachedStatement<'conn>> {
let mut cache = self.0.borrow_mut(); let mut cache = self.0.borrow_mut();
let stmt = match cache.remove(sql) { let stmt = match cache.remove(sql) {
Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)), Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)),
@ -232,4 +232,38 @@ mod test {
} }
assert_eq!(0, cache.len()); assert_eq!(0, cache.len());
} }
#[test]
fn test_ddl() {
let db = Connection::open_in_memory().unwrap();
db.execute_batch(r#"
CREATE TABLE foo (x INT);
INSERT INTO foo VALUES (1);
"#)
.unwrap();
let sql = "SELECT * FROM foo";
{
let mut stmt = db.prepare_cached(sql).unwrap();
assert_eq!(1i32,
stmt.query_map(&[], |r| r.get(0)).unwrap().next().unwrap().unwrap());
}
db.execute_batch(r#"
ALTER TABLE foo ADD COLUMN y INT;
UPDATE foo SET y = 2;
"#)
.unwrap();
{
let mut stmt = db.prepare_cached(sql).unwrap();
assert_eq!((1i32, 2i32),
stmt.query_map(&[], |r| (r.get(0), r.get(1)))
.unwrap()
.next()
.unwrap()
.unwrap());
}
}
} }

View File

@ -720,22 +720,19 @@ pub type SqliteStatement<'conn> = Statement<'conn>;
pub struct Statement<'conn> { pub struct Statement<'conn> {
conn: &'conn Connection, conn: &'conn Connection,
stmt: RawStatement, stmt: RawStatement,
column_count: c_int,
} }
impl<'conn> Statement<'conn> { impl<'conn> Statement<'conn> {
fn new(conn: &Connection, stmt: RawStatement) -> Statement { fn new(conn: &Connection, stmt: RawStatement) -> Statement {
let column_count = stmt.column_count();
Statement { Statement {
conn: conn, conn: conn,
stmt: stmt, stmt: stmt,
column_count: column_count,
} }
} }
/// Get all the column names in the result set of the prepared statement. /// Get all the column names in the result set of the prepared statement.
pub fn column_names(&self) -> Vec<&str> { pub fn column_names(&self) -> Vec<&str> {
let n = self.column_count; let n = self.column_count();
let mut cols = Vec::with_capacity(n as usize); let mut cols = Vec::with_capacity(n as usize);
for i in 0..n { for i in 0..n {
let slice = self.stmt.column_name(i); let slice = self.stmt.column_name(i);
@ -747,7 +744,7 @@ impl<'conn> Statement<'conn> {
/// Return the number of columns in the result set returned by the prepared statement. /// Return the number of columns in the result set returned by the prepared statement.
pub fn column_count(&self) -> i32 { pub fn column_count(&self) -> i32 {
self.column_count self.stmt.column_count()
} }
/// Returns the column index in the result set for a given column name. /// Returns the column index in the result set for a given column name.
@ -757,7 +754,7 @@ impl<'conn> Statement<'conn> {
/// Will return an `Error::InvalidColumnName` when there is no column with the specified `name`. /// Will return an `Error::InvalidColumnName` when there is no column with the specified `name`.
pub fn column_index(&self, name: &str) -> Result<i32> { pub fn column_index(&self, name: &str) -> Result<i32> {
let bytes = name.as_bytes(); let bytes = name.as_bytes();
let n = self.column_count; let n = self.column_count();
for i in 0..n { for i in 0..n {
if bytes == self.stmt.column_name(i).to_bytes() { if bytes == self.stmt.column_name(i).to_bytes() {
return Ok(i); return Ok(i);
@ -799,7 +796,7 @@ impl<'conn> Statement<'conn> {
self.stmt.reset(); self.stmt.reset();
match r { match r {
ffi::SQLITE_DONE => { ffi::SQLITE_DONE => {
if self.column_count == 0 { if self.column_count() == 0 {
Ok(self.conn.changes()) Ok(self.conn.changes())
} else { } else {
Err(Error::ExecuteReturnedResults) Err(Error::ExecuteReturnedResults)
@ -1166,7 +1163,7 @@ pub trait RowIndex {
impl RowIndex for i32 { impl RowIndex for i32 {
#[inline] #[inline]
fn idx(&self, stmt: &Statement) -> Result<i32> { fn idx(&self, stmt: &Statement) -> Result<i32> {
if *self < 0 || *self >= stmt.column_count { if *self < 0 || *self >= stmt.column_count() {
Err(Error::InvalidColumnIndex(*self)) Err(Error::InvalidColumnIndex(*self))
} else { } else {
Ok(*self) Ok(*self)