From 71a20048949dbe72ebce89e9ec966e4ae7421e1c Mon Sep 17 00:00:00 2001 From: gwenn Date: Sun, 10 Mar 2019 12:58:20 +0100 Subject: [PATCH] Add Rows::map method --- Cargo.toml | 1 + src/cache.rs | 15 +++++---------- src/lib.rs | 8 +++----- src/row.rs | 31 ++++++++++++++++++++++++++++++- src/session.rs | 9 +++++---- src/statement.rs | 6 +++--- src/vtab/csvtab.rs | 4 ++-- 7 files changed, 49 insertions(+), 25 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7581e09..bf7d4b9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,6 +58,7 @@ serde_json = { version = "1.0", optional = true } csv = { version = "1.0", optional = true } lazy_static = { version = "1.0", optional = true } byteorder = { version = "1.2", features = ["i128"], optional = true } +fallible-iterator = "0.1" fallible-streaming-iterator = "0.1" memchr = "2.2.0" diff --git a/src/cache.rs b/src/cache.rs index a93410f..dfece9e 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -151,6 +151,7 @@ impl StatementCache { #[cfg(test)] mod test { + use fallible_iterator::FallibleIterator; use super::StatementCache; use crate::{Connection, NO_PARAMS}; @@ -277,12 +278,9 @@ mod test { { let mut stmt = db.prepare_cached(sql).unwrap(); assert_eq!( - 1i32, - stmt.query_map::(NO_PARAMS, |r| r.get(0)) - .unwrap() + Ok(Some(1i32)), + stmt.query(NO_PARAMS).unwrap().map(|r| r.get(0)) .next() - .unwrap() - .unwrap() ); } @@ -297,12 +295,9 @@ mod test { { let mut stmt = db.prepare_cached(sql).unwrap(); assert_eq!( - (1i32, 2i32), - stmt.query_map(NO_PARAMS, |r| Ok((r.get(0)?, r.get(1)?))) - .unwrap() + Ok(Some((1i32, 2i32))), + stmt.query(NO_PARAMS).unwrap().map(|r| Ok((r.get(0)?, r.get(1)?))) .next() - .unwrap() - .unwrap() ); } } diff --git a/src/lib.rs b/src/lib.rs index dfa5c22..f7c875d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -77,8 +77,6 @@ use std::str; use std::sync::atomic::Ordering; use std::sync::{Arc, Mutex}; -pub use fallible_streaming_iterator::FallibleStreamingIterator; - use crate::cache::StatementCache; use crate::inner_connection::{InnerConnection, BYPASS_SQLITE_INIT}; use crate::raw_statement::RawStatement; @@ -846,6 +844,7 @@ unsafe fn db_filename(_: *mut ffi::sqlite3) -> Option { #[cfg(test)] mod test { + use fallible_iterator::FallibleIterator; use self::tempdir::TempDir; pub use super::*; use crate::ffi; @@ -1132,8 +1131,7 @@ mod test { let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC").unwrap(); let results: Result> = query - .query_map(NO_PARAMS, |row| row.get(1)) - .unwrap() + .query(NO_PARAMS).unwrap().map(|row| row.get(1)) .collect(); assert_eq!(results.unwrap().concat(), "hello, world!"); @@ -1322,7 +1320,7 @@ mod test { .prepare("SELECT interrupt() FROM (SELECT 1 UNION SELECT 2 UNION SELECT 3)") .unwrap(); - let result: Result> = stmt.query_map(NO_PARAMS, |r| r.get(0)).unwrap().collect(); + let result: Result> = stmt.query(NO_PARAMS).unwrap().map(|r| r.get(0)).collect(); match result.unwrap_err() { Error::SqliteFailure(err, _) => { diff --git a/src/row.rs b/src/row.rs index 72f8a7c..9188be4 100644 --- a/src/row.rs +++ b/src/row.rs @@ -1,6 +1,8 @@ +use fallible_iterator::FallibleIterator; +use fallible_streaming_iterator::FallibleStreamingIterator; use std::{convert, result}; -use super::{Error, FallibleStreamingIterator, Result, Statement}; +use super::{Error, Result, Statement}; use crate::types::{FromSql, FromSqlError, ValueRef}; /// An handle for the resulting rows of a query. @@ -32,6 +34,13 @@ impl<'stmt> Rows<'stmt> { self.advance()?; Ok((*self).get()) } + + pub fn map(self, f: F) -> Map<'stmt, F> + where + F: FnMut(&Row<'_>) -> Result, + { + Map { rows: self, f: f } + } } impl<'stmt> Rows<'stmt> { @@ -56,6 +65,26 @@ impl Drop for Rows<'_> { } } +pub struct Map<'stmt, F> { + rows: Rows<'stmt>, + f: F, +} + +impl FallibleIterator for Map<'_, F> +where + F: FnMut(&Row<'_>) -> Result, +{ + type Error = Error; + type Item = B; + + fn next(&mut self) -> Result> { + match self.rows.next()? { + Some(v) => Ok(Some((self.f)(v)?)), + None => Ok(None), + } + } +} + /// An iterator over the mapped resulting rows of a query. pub struct MappedRows<'stmt, F> { rows: Rows<'stmt>, diff --git a/src/session.rs b/src/session.rs index f56b372..2546520 100644 --- a/src/session.rs +++ b/src/session.rs @@ -10,13 +10,13 @@ use std::panic::{catch_unwind, RefUnwindSafe}; use std::ptr; use std::slice::{from_raw_parts, from_raw_parts_mut}; +use fallible_streaming_iterator::FallibleStreamingIterator; + use crate::error::error_from_sqlite_code; use crate::ffi; use crate::hooks::Action; use crate::types::ValueRef; -use crate::{ - errmsg_to_string, str_to_cstring, Connection, DatabaseName, FallibleStreamingIterator, Result, -}; +use crate::{errmsg_to_string, str_to_cstring, Connection, DatabaseName, Result}; // https://sqlite.org/session.html @@ -720,10 +720,11 @@ unsafe extern "C" fn x_output(p_out: *mut c_void, data: *const c_void, len: c_in #[cfg(test)] mod test { use std::sync::atomic::{AtomicBool, Ordering}; + use fallible_streaming_iterator::FallibleStreamingIterator; use super::{Changeset, ChangesetIter, ConflictAction, ConflictType, Session}; use crate::hooks::Action; - use crate::{Connection, FallibleStreamingIterator}; + use crate::Connection; fn one_changeset() -> Changeset { let db = Connection::open_in_memory().unwrap(); diff --git a/src/statement.rs b/src/statement.rs index 850b6d0..c2401c3 100644 --- a/src/statement.rs +++ b/src/statement.rs @@ -168,7 +168,7 @@ impl Statement<'_> { /// ## Example /// /// ```rust,no_run - /// # use rusqlite::{Connection, FallibleStreamingIterator, Result, NO_PARAMS}; + /// # use rusqlite::{Connection, Result, NO_PARAMS}; /// fn get_names(conn: &Connection) -> Result> { /// let mut stmt = conn.prepare("SELECT name FROM people")?; /// let mut rows = stmt.query(NO_PARAMS)?; @@ -204,7 +204,7 @@ impl Statement<'_> { /// ## Example /// /// ```rust,no_run - /// # use rusqlite::{Connection, FallibleStreamingIterator, Result}; + /// # use rusqlite::{Connection, Result}; /// fn query(conn: &Connection) -> Result<()> { /// let mut stmt = conn.prepare("SELECT * FROM test where name = :name")?; /// let mut rows = stmt.query_named(&[(":name", &"one")])?; @@ -219,7 +219,7 @@ impl Statement<'_> { /// and so the above example could also be written as: /// /// ```rust,no_run - /// # use rusqlite::{Connection, FallibleStreamingIterator, Result, named_params}; + /// # use rusqlite::{Connection, Result, named_params}; /// fn query(conn: &Connection) -> Result<()> { /// let mut stmt = conn.prepare("SELECT * FROM test where name = :name")?; /// let mut rows = stmt.query_named(named_params!{ ":name": "one" })?; diff --git a/src/vtab/csvtab.rs b/src/vtab/csvtab.rs index 8e733f4..7e52655 100644 --- a/src/vtab/csvtab.rs +++ b/src/vtab/csvtab.rs @@ -345,6 +345,7 @@ impl From for Error { #[cfg(test)] mod test { + use fallible_iterator::FallibleIterator; use crate::vtab::csvtab; use crate::{Connection, Result, NO_PARAMS}; @@ -363,8 +364,7 @@ mod test { } let ids: Result> = s - .query_map(NO_PARAMS, |row| row.get::<_, i32>(0)) - .unwrap() + .query(NO_PARAMS).unwrap().map(|row| row.get::<_, i32>(0)) .collect(); let sum = ids.unwrap().iter().sum::(); assert_eq!(sum, 15);