diff --git a/src/lib.rs b/src/lib.rs index f12e5da..26fcc9f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -630,18 +630,26 @@ impl<'conn> SqliteStatement<'conn> { self.reset_if_needed(); unsafe { - assert!(params.len() as c_int == ffi::sqlite3_bind_parameter_count(self.stmt), - "incorrect number of parameters to query(): expected {}, got {}", - ffi::sqlite3_bind_parameter_count(self.stmt), - params.len()); - - for (i, p) in params.iter().enumerate() { - try!(self.conn.decode_result(p.bind_parameter(self.stmt, (i + 1) as c_int))); - } - - self.needs_reset = true; - Ok(SqliteRows::new(self)) + try!(self.bind_parameters(params)); } + + Ok(SqliteRows::new(self)) + } + + pub fn query_map<'a, 'map, T, F>(&'a mut self, params: &[&ToSql], f: &'map F) + -> SqliteResult> + where T: 'static, + F: Fn(MappedRow) -> T { + self.reset_if_needed(); + + unsafe { + try!(self.bind_parameters(params)); + } + + Ok(MappedRows { + stmt: self, + map: f + }) } /// Consumes the statement. @@ -652,6 +660,21 @@ impl<'conn> SqliteStatement<'conn> { self.finalize_() } + unsafe fn bind_parameters(&mut self, params: &[&ToSql]) -> SqliteResult<()> { + assert!(params.len() as c_int == ffi::sqlite3_bind_parameter_count(self.stmt), + "incorrect number of parameters to query(): expected {}, got {}", + ffi::sqlite3_bind_parameter_count(self.stmt), + params.len()); + + for (i, p) in params.iter().enumerate() { + try!(self.conn.decode_result(p.bind_parameter(self.stmt, (i + 1) as c_int))); + } + + self.needs_reset = true; + + Ok(()) + } + fn reset_if_needed(&mut self) { if self.needs_reset { unsafe { ffi::sqlite3_reset(self.stmt); }; @@ -679,6 +702,44 @@ impl<'conn> Drop for SqliteStatement<'conn> { } } +pub struct MappedRows<'stmt, 'map, T> { + stmt: &'stmt SqliteStatement<'stmt>, + map: &'map Fn(MappedRow) -> T +} + +impl<'stmt, 'map, T: 'static> Iterator for MappedRows<'stmt, 'map, T> { + type Item = SqliteResult; + + fn next(&mut self) -> Option> { + match unsafe { ffi::sqlite3_step(self.stmt.stmt) } { + ffi::SQLITE_ROW => { + Some(Ok((*self.map)(MappedRow(self.stmt)))) + }, + ffi::SQLITE_DONE => None, + code => { + Some(Err(self.stmt.conn.decode_result(code).unwrap_err())) + } + } + } +} + +pub struct MappedRow<'stmt>(&'stmt SqliteStatement<'stmt>); + +impl<'stmt> MappedRow<'stmt> { + pub fn get(&self, idx: c_int) -> T { + self.get_opt(idx).unwrap() + } + + pub fn get_opt(&self, idx: c_int) -> SqliteResult { + // Do assertions because these are logic errors. + // We can probably skip them in release builds. + assert!(idx >= 0); + assert!(idx < unsafe { ffi::sqlite3_column_count(self.0.stmt) }); + + unsafe { FromSql::column_result(self.0.stmt, idx) } + } +} + /// An iterator over the resulting rows of a query. /// /// ## Warning @@ -947,6 +1008,24 @@ mod test { } } + #[test] + fn test_query_map() { + let db = checked_memory_handle(); + let sql = "BEGIN; + CREATE TABLE foo(x INTEGER, y TEXT); + INSERT INTO foo VALUES(4, \"hello\"); + INSERT INTO foo VALUES(3, \", \"); + INSERT INTO foo VALUES(2, \"world\"); + INSERT INTO foo VALUES(1, \"!\"); + END;"; + db.execute_batch(sql).unwrap(); + + let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC").unwrap(); + let results: SqliteResult> = query.query_map(&[], &(|row| row.get(1))).unwrap().collect(); + + assert_eq!(results.unwrap().concat(), "hello, world!"); + } + #[test] fn test_query_row() { let db = checked_memory_handle();