diff --git a/src/blob.rs b/src/blob.rs index 9f50980..c24fce9 100644 --- a/src/blob.rs +++ b/src/blob.rs @@ -159,9 +159,8 @@ impl<'conn> io::Read for Blob<'conn> { if n <= 0 { return Ok(0); } - let rc = unsafe { - ffi::sqlite3_blob_read(self.blob, mem::transmute(buf.as_ptr()), n, self.pos) - }; + let rc = + unsafe { ffi::sqlite3_blob_read(self.blob, mem::transmute(buf.as_ptr()), n, self.pos) }; self.conn .decode_result(rc) .map(|_| { @@ -353,7 +352,7 @@ mod test { { // ... but it should've written the first 10 bytes let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false) - .unwrap(); + .unwrap(); let mut bytes = [0u8; 10]; assert_eq!(10, blob.read(&mut bytes[..]).unwrap()); assert_eq!(b"0123456701", &bytes); @@ -371,7 +370,7 @@ mod test { { // ... but it should've written the first 10 bytes let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false) - .unwrap(); + .unwrap(); let mut bytes = [0u8; 10]; assert_eq!(10, blob.read(&mut bytes[..]).unwrap()); assert_eq!(b"aaaaaaaaaa", &bytes); diff --git a/src/convenient.rs b/src/convenient.rs index b7102f8..5f867b4 100644 --- a/src/convenient.rs +++ b/src/convenient.rs @@ -23,12 +23,14 @@ impl<'conn> Statement<'conn> { /// Return `true` if a query in the SQL statement it executes returns one or more rows /// and `false` if the SQL returns an empty set. pub fn exists(&mut self, params: &[&ToSql]) -> Result { - self.reset_if_needed(); let mut rows = try!(self.query(params)); - match rows.next() { - Some(_) => Ok(true), - None => Ok(false), - } + let exists = { + match rows.next() { + Some(_) => true, + None => false, + } + }; + Ok(exists) } } diff --git a/src/functions.rs b/src/functions.rs index 8fd78f0..5842880 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -333,7 +333,9 @@ impl<'a> Context<'a> { /// /// `A` is the type of the aggregation context and `T` is the type of the final result. /// Implementations should be stateless. -pub trait Aggregate where T: ToResult { +pub trait Aggregate + where T: ToResult +{ /// Initializes the aggregation context. Will be called prior to the first call /// to `step()` to set up the context for an invocation of the function. (Note: /// `init()` will not be called if the there are no rows.) @@ -769,16 +771,16 @@ mod test { fn test_varargs_function() { let db = Connection::open_in_memory().unwrap(); db.create_scalar_function("my_concat", -1, true, |ctx| { - let mut ret = String::new(); + let mut ret = String::new(); - for idx in 0..ctx.len() { - let s = try!(ctx.get::(idx)); - ret.push_str(&s); - } + for idx in 0..ctx.len() { + let s = try!(ctx.get::(idx)); + ret.push_str(&s); + } - Ok(ret) - }) - .unwrap(); + Ok(ret) + }) + .unwrap(); for &(expected, query) in &[("", "SELECT my_concat()"), ("onetwo", "SELECT my_concat('one', 'two')"), @@ -829,18 +831,18 @@ mod test { // sum should return NULL when given no columns (contrast with count below) let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)"; let result: Option = db.query_row(no_result, &[], |r| r.get(0)) - .unwrap(); + .unwrap(); assert!(result.is_none()); let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; let result: i64 = db.query_row(single_sum, &[], |r| r.get(0)) - .unwrap(); + .unwrap(); assert_eq!(4, result); let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \ 2, 1)"; let result: (i64, i64) = db.query_row(dual_sum, &[], |r| (r.get(0), r.get(1))) - .unwrap(); + .unwrap(); assert_eq!((4, 2), result); } @@ -856,7 +858,7 @@ mod test { let single_sum = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)"; let result: i64 = db.query_row(single_sum, &[], |r| r.get(0)) - .unwrap(); + .unwrap(); assert_eq!(2, result); } } diff --git a/src/lib.rs b/src/lib.rs index 3460211..09e68bb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -521,8 +521,8 @@ impl Connection { impl fmt::Debug for Connection { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("Connection") - .field("path", &self.path) - .finish() + .field("path", &self.path) + .finish() } } @@ -708,7 +708,6 @@ pub type SqliteStatement<'conn> = Statement<'conn>; pub struct Statement<'conn> { conn: &'conn Connection, stmt: *mut ffi::sqlite3_stmt, - needs_reset: bool, column_count: c_int, } @@ -717,7 +716,6 @@ impl<'conn> Statement<'conn> { Statement { conn: conn, stmt: stmt, - needs_reset: false, column_count: unsafe { ffi::sqlite3_column_count(stmt) }, } } @@ -826,13 +824,10 @@ impl<'conn> Statement<'conn> { /// /// Will return `Err` if binding parameters fails. pub fn query<'a>(&'a mut self, params: &[&ToSql]) -> Result> { - self.reset_if_needed(); - unsafe { try!(self.bind_parameters(params)); } - self.needs_reset = true; Ok(Rows::new(self)) } @@ -906,15 +901,6 @@ impl<'conn> Statement<'conn> { Ok(()) } - fn reset_if_needed(&mut self) { - if self.needs_reset { - unsafe { - ffi::sqlite3_reset(self.stmt); - }; - self.needs_reset = false; - } - } - fn finalize_(&mut self) -> Result<()> { let r = unsafe { ffi::sqlite3_finalize(self.stmt) }; self.stmt = ptr::null_mut(); @@ -929,10 +915,10 @@ impl<'conn> fmt::Debug for Statement<'conn> { str::from_utf8(c_slice) }; f.debug_struct("Statement") - .field("conn", self.conn) - .field("stmt", &self.stmt) - .field("sql", &sql) - .finish() + .field("conn", self.conn) + .field("stmt", &self.stmt) + .field("sql", &sql) + .finish() } } @@ -949,7 +935,8 @@ pub struct MappedRows<'stmt, F> { map: F, } -impl<'stmt, T, F> Iterator for MappedRows<'stmt, F> where F: FnMut(&Row) -> T +impl<'stmt, T, F> Iterator for MappedRows<'stmt, F> + where F: FnMut(&Row) -> T { type Item = Result; @@ -974,7 +961,7 @@ impl<'stmt, T, E, F> Iterator for AndThenRows<'stmt, F> fn next(&mut self) -> Option { self.rows.next().map(|row_result| { row_result.map_err(E::from) - .and_then(|row| (self.map)(&row)) + .and_then(|row| (self.map)(&row)) }) } } @@ -1015,17 +1002,15 @@ pub type SqliteRows<'stmt> = Rows<'stmt>; /// `min`/`max` (which could return a stale row unless the last row happened to be the min or max, /// respectively). pub struct Rows<'stmt> { - stmt: &'stmt Statement<'stmt>, + stmt: Option<&'stmt Statement<'stmt>>, current_row: Rc>, - failed: bool, } impl<'stmt> Rows<'stmt> { fn new(stmt: &'stmt Statement<'stmt>) -> Rows<'stmt> { Rows { - stmt: stmt, + stmt: Some(stmt), current_row: Rc::new(Cell::new(0)), - failed: false, } } @@ -1035,31 +1020,47 @@ impl<'stmt> Rows<'stmt> { None => Err(Error::QueryReturnedNoRows), } } + + fn reset(&mut self) { + if let Some(stmt) = self.stmt.take() { + unsafe { + ffi::sqlite3_reset(stmt.stmt); + } + } + } } impl<'stmt> Iterator for Rows<'stmt> { type Item = Result>; fn next(&mut self) -> Option>> { - if self.failed { - return None; - } - match unsafe { ffi::sqlite3_step(self.stmt.stmt) } { - ffi::SQLITE_ROW => { - let current_row = self.current_row.get() + 1; - self.current_row.set(current_row); - Some(Ok(Row { - stmt: self.stmt, - current_row: self.current_row.clone(), - row_idx: current_row, - })) + self.stmt.and_then(|stmt| { + match unsafe { ffi::sqlite3_step(stmt.stmt) } { + ffi::SQLITE_ROW => { + let current_row = self.current_row.get() + 1; + self.current_row.set(current_row); + Some(Ok(Row { + stmt: stmt, + current_row: self.current_row.clone(), + row_idx: current_row, + })) + } + ffi::SQLITE_DONE => { + self.reset(); + None + } + code => { + self.reset(); + Some(Err(stmt.conn.decode_result(code).unwrap_err())) + } } - ffi::SQLITE_DONE => None, - code => { - self.failed = true; - Some(Err(self.stmt.conn.decode_result(code).unwrap_err())) - } - } + }) + } +} + +impl<'stmt> Drop for Rows<'stmt> { + fn drop(&mut self) { + self.reset(); } } diff --git a/src/named_params.rs b/src/named_params.rs index 34b3c67..3e62610 100644 --- a/src/named_params.rs +++ b/src/named_params.rs @@ -113,10 +113,7 @@ impl<'conn> Statement<'conn> { /// /// Will return `Err` if binding parameters fails. pub fn query_named<'a>(&'a mut self, params: &[(&str, &ToSql)]) -> Result> { - self.reset_if_needed(); try!(self.bind_parameters_named(params)); - - self.needs_reset = true; Ok(Rows::new(self)) } @@ -190,10 +187,9 @@ mod test { let mut stmt = db.prepare("INSERT INTO test (x, y) VALUES (:x, :y)").unwrap(); stmt.execute_named(&[(":x", &"one")]).unwrap(); - let result: Option = db.query_row("SELECT y FROM test WHERE x = 'one'", - &[], - |row| row.get(0)) - .unwrap(); + let result: Option = + db.query_row("SELECT y FROM test WHERE x = 'one'", &[], |row| row.get(0)) + .unwrap(); assert!(result.is_none()); } @@ -207,10 +203,9 @@ mod test { stmt.execute_named(&[(":x", &"one")]).unwrap(); stmt.execute_named(&[(":y", &"two")]).unwrap(); - let result: String = db.query_row("SELECT x FROM test WHERE y = 'two'", - &[], - |row| row.get(0)) - .unwrap(); + let result: String = + db.query_row("SELECT x FROM test WHERE y = 'two'", &[], |row| row.get(0)) + .unwrap(); assert_eq!(result, "one"); } } diff --git a/src/trace.rs b/src/trace.rs index 7dd1417..4cea711 100644 --- a/src/trace.rs +++ b/src/trace.rs @@ -4,7 +4,6 @@ use libc::{c_char, c_int, c_void}; use std::ffi::{CStr, CString}; use std::mem; use std::ptr; -use std::str; use std::time::Duration; use super::ffi; diff --git a/src/types/mod.rs b/src/types/mod.rs index 841242b..58b4dbd 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -379,8 +379,8 @@ mod test { let db = checked_memory_handle(); db.execute("INSERT INTO foo(b, t, i, f) VALUES (X'0102', 'text', 1, 1.5)", - &[]) - .unwrap(); + &[]) + .unwrap(); let mut stmt = db.prepare("SELECT b, t, i, f, n FROM foo").unwrap(); let mut rows = stmt.query(&[]).unwrap(); @@ -442,8 +442,8 @@ mod test { let db = checked_memory_handle(); db.execute("INSERT INTO foo(b, t, i, f) VALUES (X'0102', 'text', 1, 1.5)", - &[]) - .unwrap(); + &[]) + .unwrap(); let mut stmt = db.prepare("SELECT b, t, i, f, n FROM foo").unwrap(); let mut rows = stmt.query(&[]).unwrap();