diff --git a/Cargo.toml b/Cargo.toml index 293779b..7581e09 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,6 +59,7 @@ csv = { version = "1.0", optional = true } lazy_static = { version = "1.0", optional = true } byteorder = { version = "1.2", features = ["i128"], optional = true } fallible-streaming-iterator = "0.1" +memchr = "2.2.0" [dev-dependencies] tempdir = "0.3" diff --git a/README.md b/README.md index f2d983e..268f23f 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ an interface similar to [rust-postgres](https://github.com/sfackler/rust-postgre ```rust use rusqlite::types::ToSql; -use rusqlite::{Connection, NO_PARAMS}; +use rusqlite::{Connection, Result, NO_PARAMS}; use time::Timespec; #[derive(Debug)] @@ -22,8 +22,8 @@ struct Person { data: Option>, } -fn main() { - let conn = Connection::open_in_memory().unwrap(); +fn main() -> Result<()> { + let conn = Connection::open_in_memory()?; conn.execute( "CREATE TABLE person ( @@ -33,7 +33,7 @@ fn main() { data BLOB )", NO_PARAMS, - ).unwrap(); + )?; let me = Person { id: 0, name: "Steven".to_string(), @@ -44,22 +44,22 @@ fn main() { "INSERT INTO person (name, time_created, data) VALUES (?1, ?2, ?3)", &[&me.name as &ToSql, &me.time_created, &me.data], - ).unwrap(); + )?; let mut stmt = conn - .prepare("SELECT id, name, time_created, data FROM person") - .unwrap(); + .prepare("SELECT id, name, time_created, data FROM person")?; let person_iter = stmt - .query_map(NO_PARAMS, |row| Person { - id: row.get(0), - name: row.get(1), - time_created: row.get(2), - data: row.get(3), - }).unwrap(); + .query_map(NO_PARAMS, |row| Ok(Person { + id: row.get(0)?, + name: row.get(1)?, + time_created: row.get(2)?, + data: row.get(3)?, + }))?; for person in person_iter { println!("Found person {:?}", person.unwrap()); } + Ok(()) } ``` diff --git a/appveyor.yml b/appveyor.yml index db0bfad..469e493 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -16,10 +16,10 @@ install: - rustc -V - cargo -V # download SQLite dll (useful only when the `bundled` feature is not set) - - appveyor-retry appveyor DownloadFile https://sqlite.org/2018/sqlite-dll-win64-x64-3250000.zip -FileName sqlite-dll-win64-x64.zip + - appveyor-retry appveyor DownloadFile https://sqlite.org/2018/sqlite-dll-win64-x64-3250200.zip -FileName sqlite-dll-win64-x64.zip - if not defined VCPKG_DEFAULT_TRIPLET 7z e sqlite-dll-win64-x64.zip -y > nul # download SQLite headers (useful only when the `bundled` feature is not set) - - appveyor-retry appveyor DownloadFile https://sqlite.org/2018/sqlite-amalgamation-3250000.zip -FileName sqlite-amalgamation.zip + - appveyor-retry appveyor DownloadFile https://sqlite.org/2018/sqlite-amalgamation-3250200.zip -FileName sqlite-amalgamation.zip - if not defined VCPKG_DEFAULT_TRIPLET 7z e sqlite-amalgamation.zip -y > nul # specify where the SQLite dll has been downloaded (useful only when the `bundled` feature is not set) - if not defined VCPKG_DEFAULT_TRIPLET SET SQLITE3_LIB_DIR=%APPVEYOR_BUILD_FOLDER% @@ -33,8 +33,8 @@ build: false test_script: - cargo test --lib --verbose - cargo test --lib --verbose --features bundled - - cargo test --lib --features "backup blob chrono csvtab functions hooks limits load_extension serde_json trace vtab" - - cargo test --lib --features "backup blob chrono csvtab functions hooks limits load_extension serde_json trace vtab buildtime_bindgen" + - cargo test --lib --features "backup blob chrono functions hooks limits load_extension serde_json trace" + - cargo test --lib --features "backup blob chrono functions hooks limits load_extension serde_json trace buildtime_bindgen" - cargo test --lib --features "backup blob chrono csvtab functions hooks limits load_extension serde_json trace vtab bundled" - cargo test --lib --features "backup blob chrono csvtab functions hooks limits load_extension serde_json trace vtab bundled buildtime_bindgen" diff --git a/src/blob.rs b/src/blob.rs index 0f3d168..372fc4c 100644 --- a/src/blob.rs +++ b/src/blob.rs @@ -16,43 +16,40 @@ //! ```rust //! use rusqlite::blob::ZeroBlob; //! use rusqlite::{Connection, DatabaseName, NO_PARAMS}; +//! use std::error::Error; //! use std::io::{Read, Seek, SeekFrom, Write}; //! -//! fn main() { -//! let db = Connection::open_in_memory().unwrap(); -//! db.execute_batch("CREATE TABLE test (content BLOB);") -//! .unwrap(); +//! fn main() -> Result<(), Box> { +//! let db = Connection::open_in_memory()?; +//! db.execute_batch("CREATE TABLE test (content BLOB);")?; //! db.execute( //! "INSERT INTO test (content) VALUES (ZEROBLOB(10))", //! NO_PARAMS, -//! ) -//! .unwrap(); +//! )?; //! //! let rowid = db.last_insert_rowid(); -//! let mut blob = db -//! .blob_open(DatabaseName::Main, "test", "content", rowid, false) -//! .unwrap(); +//! let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false)?; //! //! // Make sure to test that the number of bytes written matches what you expect; //! // if you try to write too much, the data will be truncated to the size of the //! // BLOB. -//! let bytes_written = blob.write(b"01234567").unwrap(); +//! let bytes_written = blob.write(b"01234567")?; //! assert_eq!(bytes_written, 8); //! //! // Same guidance - make sure you check the number of bytes read! -//! blob.seek(SeekFrom::Start(0)).unwrap(); +//! blob.seek(SeekFrom::Start(0))?; //! let mut buf = [0u8; 20]; -//! let bytes_read = blob.read(&mut buf[..]).unwrap(); +//! let bytes_read = blob.read(&mut buf[..])?; //! assert_eq!(bytes_read, 10); // note we read 10 bytes because the blob has size 10 //! -//! db.execute("INSERT INTO test (content) VALUES (?)", &[ZeroBlob(64)]) -//! .unwrap(); +//! db.execute("INSERT INTO test (content) VALUES (?)", &[ZeroBlob(64)])?; //! //! // given a new row ID, we can reopen the blob on that row //! let rowid = db.last_insert_rowid(); -//! blob.reopen(rowid).unwrap(); +//! blob.reopen(rowid)?; //! //! assert_eq!(blob.size(), 64); +//! Ok(()) //! } //! ``` use std::cmp::min; diff --git a/src/busy.rs b/src/busy.rs index e0ee835..f801e3f 100644 --- a/src/busy.rs +++ b/src/busy.rs @@ -82,7 +82,7 @@ mod test { use std::time::Duration; use tempdir; - use crate::{Connection, Error, ErrorCode, TransactionBehavior, NO_PARAMS}; + use crate::{Connection, Error, ErrorCode, Result, TransactionBehavior, NO_PARAMS}; #[test] fn test_default_busy() { @@ -94,7 +94,7 @@ mod test { .transaction_with_behavior(TransactionBehavior::Exclusive) .unwrap(); let db2 = Connection::open(&path).unwrap(); - let r = db2.query_row("PRAGMA schema_version", NO_PARAMS, |_| unreachable!()); + let r: Result<()> = db2.query_row("PRAGMA schema_version", NO_PARAMS, |_| unreachable!()); match r.unwrap_err() { Error::SqliteFailure(err, _) => { assert_eq!(err.code, ErrorCode::DatabaseBusy); @@ -127,7 +127,7 @@ mod test { assert_eq!(tx.recv().unwrap(), 1); let _ = db2 .query_row("PRAGMA schema_version", NO_PARAMS, |row| { - row.get_checked::<_, i32>(0) + row.get::<_, i32>(0) }) .expect("unexpected error"); @@ -166,7 +166,7 @@ mod test { assert_eq!(tx.recv().unwrap(), 1); let _ = db2 .query_row("PRAGMA schema_version", NO_PARAMS, |row| { - row.get_checked::<_, i32>(0) + row.get::<_, i32>(0) }) .expect("unexpected error"); assert_eq!(CALLED.load(Ordering::Relaxed), true); diff --git a/src/cache.rs b/src/cache.rs index 5bc7d47..a93410f 100644 --- a/src/cache.rs +++ b/src/cache.rs @@ -152,7 +152,7 @@ impl StatementCache { #[cfg(test)] mod test { use super::StatementCache; - use crate::{Connection, FallibleStreamingIterator, NO_PARAMS}; + use crate::{Connection, NO_PARAMS}; impl StatementCache { fn clear(&self) { @@ -277,9 +277,12 @@ mod test { { let mut stmt = db.prepare_cached(sql).unwrap(); assert_eq!( - Ok(Some(&1i32)), - stmt.query(NO_PARAMS).unwrap().map(|r| r.get(0)) + 1i32, + stmt.query_map::(NO_PARAMS, |r| r.get(0)) + .unwrap() .next() + .unwrap() + .unwrap() ); } @@ -294,9 +297,12 @@ mod test { { let mut stmt = db.prepare_cached(sql).unwrap(); assert_eq!( - Ok(Some(&(1i32, 2i32))), - stmt.query(NO_PARAMS).unwrap().map(|r| (r.get(0), r.get(1))) + (1i32, 2i32), + stmt.query_map(NO_PARAMS, |r| Ok((r.get(0)?, r.get(1)?))) + .unwrap() .next() + .unwrap() + .unwrap() ); } } diff --git a/src/config.rs b/src/config.rs index 36337f0..a10faac 100644 --- a/src/config.rs +++ b/src/config.rs @@ -25,12 +25,18 @@ pub enum DbConfig { impl Connection { /// Returns the current value of a `config`. /// - /// - SQLITE_DBCONFIG_ENABLE_FKEY: return `false` or `true` to indicate whether FK enforcement is off or on - /// - SQLITE_DBCONFIG_ENABLE_TRIGGER: return `false` or `true` to indicate whether triggers are disabled or enabled - /// - SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER: return `false` or `true` to indicate whether fts3_tokenizer are disabled or enabled - /// - SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE: return `false` to indicate checkpoints-on-close are not disabled or `true` if they are - /// - SQLITE_DBCONFIG_ENABLE_QPSG: return `false` or `true` to indicate whether the QPSG is disabled or enabled - /// - SQLITE_DBCONFIG_TRIGGER_EQP: return `false` to indicate output-for-trigger are not disabled or `true` if it is + /// - SQLITE_DBCONFIG_ENABLE_FKEY: return `false` or `true` to indicate + /// whether FK enforcement is off or on + /// - SQLITE_DBCONFIG_ENABLE_TRIGGER: return `false` or `true` to indicate + /// whether triggers are disabled or enabled + /// - SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER: return `false` or `true` to + /// indicate whether fts3_tokenizer are disabled or enabled + /// - SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE: return `false` to indicate + /// checkpoints-on-close are not disabled or `true` if they are + /// - SQLITE_DBCONFIG_ENABLE_QPSG: return `false` or `true` to indicate + /// whether the QPSG is disabled or enabled + /// - SQLITE_DBCONFIG_TRIGGER_EQP: return `false` to indicate + /// output-for-trigger are not disabled or `true` if it is pub fn db_config(&self, config: DbConfig) -> Result { let c = self.db.borrow(); unsafe { @@ -47,12 +53,18 @@ impl Connection { /// Make configuration changes to a database connection /// - /// - SQLITE_DBCONFIG_ENABLE_FKEY: `false` to disable FK enforcement, `true` to enable FK enforcement - /// - SQLITE_DBCONFIG_ENABLE_TRIGGER: `false` to disable triggers, `true` to enable triggers - /// - SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER: `false` to disable fts3_tokenizer(), `true` to enable fts3_tokenizer() - /// - SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE: `false` (the default) to enable checkpoints-on-close, `true` to disable them - /// - SQLITE_DBCONFIG_ENABLE_QPSG: `false` to disable the QPSG, `true` to enable QPSG - /// - SQLITE_DBCONFIG_TRIGGER_EQP: `false` to disable output for trigger programs, `true` to enable it + /// - SQLITE_DBCONFIG_ENABLE_FKEY: `false` to disable FK enforcement, `true` + /// to enable FK enforcement + /// - SQLITE_DBCONFIG_ENABLE_TRIGGER: `false` to disable triggers, `true` to + /// enable triggers + /// - SQLITE_DBCONFIG_ENABLE_FTS3_TOKENIZER: `false` to disable + /// fts3_tokenizer(), `true` to enable fts3_tokenizer() + /// - SQLITE_DBCONFIG_NO_CKPT_ON_CLOSE: `false` (the default) to enable + /// checkpoints-on-close, `true` to disable them + /// - SQLITE_DBCONFIG_ENABLE_QPSG: `false` to disable the QPSG, `true` to + /// enable QPSG + /// - SQLITE_DBCONFIG_TRIGGER_EQP: `false` to disable output for trigger + /// programs, `true` to enable it pub fn set_db_config(&self, config: DbConfig, new_val: bool) -> Result { let c = self.db.borrow_mut(); unsafe { @@ -78,11 +90,25 @@ mod test { let db = Connection::open_in_memory().unwrap(); let opposite = !db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY).unwrap(); - assert_eq!(db.set_db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY, opposite), Ok(opposite)); - assert_eq!(db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY), Ok(opposite)); + assert_eq!( + db.set_db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY, opposite), + Ok(opposite) + ); + assert_eq!( + db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY), + Ok(opposite) + ); - let opposite = !db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER).unwrap(); - assert_eq!(db.set_db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER, opposite), Ok(opposite)); - assert_eq!(db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER), Ok(opposite)); + let opposite = !db + .db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER) + .unwrap(); + assert_eq!( + db.set_db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER, opposite), + Ok(opposite) + ); + assert_eq!( + db.db_config(DbConfig::SQLITE_DBCONFIG_ENABLE_TRIGGER), + Ok(opposite) + ); } } diff --git a/src/context.rs b/src/context.rs index a7bb468..ad0a3ad 100644 --- a/src/context.rs +++ b/src/context.rs @@ -7,7 +7,7 @@ use std::rc::Rc; use crate::ffi; use crate::ffi::sqlite3_context; -use crate::str_to_cstring; +use crate::str_for_sqlite; use crate::types::{ToSqlOutput, ValueRef}; #[cfg(feature = "array")] use crate::vtab::array::{free_array, ARRAY_TYPE}; @@ -38,25 +38,20 @@ pub(crate) unsafe fn set_result(ctx: *mut sqlite3_context, result: &ToSqlOutput< ValueRef::Real(r) => ffi::sqlite3_result_double(ctx, r), ValueRef::Text(s) => { let length = s.len(); - if length > ::std::i32::MAX as usize { + if length > c_int::max_value() as usize { ffi::sqlite3_result_error_toobig(ctx); } else { - let c_str = match str_to_cstring(s) { + let (c_str, len, destructor) = match str_for_sqlite(s) { Ok(c_str) => c_str, // TODO sqlite3_result_error Err(_) => return ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISUSE), }; - let destructor = if length > 0 { - ffi::SQLITE_TRANSIENT() - } else { - ffi::SQLITE_STATIC() - }; - ffi::sqlite3_result_text(ctx, c_str.as_ptr(), length as c_int, destructor); + ffi::sqlite3_result_text(ctx, c_str, len, destructor); } } ValueRef::Blob(b) => { let length = b.len(); - if length > ::std::i32::MAX as usize { + if length > c_int::max_value() as usize { ffi::sqlite3_result_error_toobig(ctx); } else if length == 0 { ffi::sqlite3_result_zeroblob(ctx, 0) diff --git a/src/functions.rs b/src/functions.rs index ddacb84..a4773a0 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -34,19 +34,18 @@ //! }) //! } //! -//! fn main() { -//! let db = Connection::open_in_memory().unwrap(); -//! add_regexp_function(&db).unwrap(); +//! fn main() -> Result<()> { +//! let db = Connection::open_in_memory()?; +//! add_regexp_function(&db)?; //! -//! let is_match: bool = db -//! .query_row( -//! "SELECT regexp('[aeiou]*', 'aaaaeeeiii')", -//! NO_PARAMS, -//! |row| row.get(0), -//! ) -//! .unwrap(); +//! let is_match: bool = db.query_row( +//! "SELECT regexp('[aeiou]*', 'aaaaeeeiii')", +//! NO_PARAMS, +//! |row| row.get(0), +//! )?; //! //! assert!(is_match); +//! Ok(()) //! } //! ``` use std::error::Error as StdError; @@ -771,7 +770,7 @@ mod test { let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \ 2, 1)"; let result: (i64, i64) = db - .query_row(dual_sum, NO_PARAMS, |r| (r.get(0), r.get(1))) + .query_row(dual_sum, NO_PARAMS, |r| Ok((r.get(0)?, r.get(1)?))) .unwrap(); assert_eq!((4, 2), result); } diff --git a/src/inner_connection.rs b/src/inner_connection.rs index 9842611..b54f3d4 100644 --- a/src/inner_connection.rs +++ b/src/inner_connection.rs @@ -5,11 +5,11 @@ use std::os::raw::c_int; use std::path::Path; use std::ptr; use std::str; -use std::sync::atomic::{AtomicBool, Ordering, ATOMIC_BOOL_INIT}; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex, Once, ONCE_INIT}; use super::ffi; -use super::str_to_cstring; +use super::{str_for_sqlite, str_to_cstring}; use super::{Connection, InterruptHandle, OpenFlags, Result}; use crate::error::{error_from_handle, error_from_sqlite_code, Error}; use crate::raw_statement::RawStatement; @@ -207,20 +207,16 @@ impl InnerConnection { } pub fn prepare<'a>(&mut self, conn: &'a Connection, sql: &str) -> Result> { - if sql.len() >= ::std::i32::MAX as usize { - return Err(error_from_sqlite_code(ffi::SQLITE_TOOBIG, None)); - } let mut c_stmt: *mut ffi::sqlite3_stmt = unsafe { mem::uninitialized() }; - let c_sql = str_to_cstring(sql)?; - let len_with_nul = (sql.len() + 1) as c_int; + let (c_sql, len, _) = str_for_sqlite(sql)?; let r = unsafe { if cfg!(feature = "unlock_notify") { let mut rc; loop { rc = ffi::sqlite3_prepare_v2( self.db(), - c_sql.as_ptr(), - len_with_nul, + c_sql, + len, &mut c_stmt, ptr::null_mut(), ); @@ -234,13 +230,7 @@ impl InnerConnection { } rc } else { - ffi::sqlite3_prepare_v2( - self.db(), - c_sql.as_ptr(), - len_with_nul, - &mut c_stmt, - ptr::null_mut(), - ) + ffi::sqlite3_prepare_v2(self.db(), c_sql, len, &mut c_stmt, ptr::null_mut()) } }; self.decode_result(r) @@ -292,7 +282,7 @@ impl Drop for InnerConnection { #[cfg(not(feature = "bundled"))] static SQLITE_VERSION_CHECK: Once = ONCE_INIT; #[cfg(not(feature = "bundled"))] -pub static BYPASS_VERSION_CHECK: AtomicBool = ATOMIC_BOOL_INIT; +pub static BYPASS_VERSION_CHECK: AtomicBool = AtomicBool::new(false); #[cfg(not(feature = "bundled"))] fn ensure_valid_sqlite_version() { @@ -339,7 +329,7 @@ rusqlite was built against SQLite {} but the runtime SQLite version is {}. To fi } static SQLITE_INIT: Once = ONCE_INIT; -pub static BYPASS_SQLITE_INIT: AtomicBool = ATOMIC_BOOL_INIT; +pub static BYPASS_SQLITE_INIT: AtomicBool = AtomicBool::new(false); fn ensure_safe_sqlite_threading_mode() -> Result<()> { // Ensure SQLite was compiled in thredsafe mode. diff --git a/src/lib.rs b/src/lib.rs index f1e7764..dfa5c22 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,7 +3,7 @@ //! //! ```rust //! use rusqlite::types::ToSql; -//! use rusqlite::{params, Connection}; +//! use rusqlite::{params, Connection, Result}; //! use time::Timespec; //! //! #[derive(Debug)] @@ -14,8 +14,8 @@ //! data: Option>, //! } //! -//! fn main() { -//! let conn = Connection::open_in_memory().unwrap(); +//! fn main() -> Result<()> { +//! let conn = Connection::open_in_memory()?; //! //! conn.execute( //! "CREATE TABLE person ( @@ -25,8 +25,7 @@ //! data BLOB //! )", //! params![], -//! ) -//! .unwrap(); +//! )?; //! let me = Person { //! id: 0, //! name: "Steven".to_string(), @@ -37,24 +36,22 @@ //! "INSERT INTO person (name, time_created, data) //! VALUES (?1, ?2, ?3)", //! params![me.name, me.time_created, me.data], -//! ) -//! .unwrap(); +//! )?; //! -//! let mut stmt = conn -//! .prepare("SELECT id, name, time_created, data FROM person") -//! .unwrap(); -//! let person_iter = stmt -//! .query_map(params![], |row| Person { -//! id: row.get(0), -//! name: row.get(1), -//! time_created: row.get(2), -//! data: row.get(3), +//! let mut stmt = conn.prepare("SELECT id, name, time_created, data FROM person")?; +//! let person_iter = stmt.query_map(params![], |row| { +//! Ok(Person { +//! id: row.get(0)?, +//! name: row.get(1)?, +//! time_created: row.get(2)?, +//! data: row.get(3)?, //! }) -//! .unwrap(); +//! })?; //! //! for person in person_iter { //! println!("Found person {:?}", person.unwrap()); //! } +//! Ok(()) //! } //! ``` #![allow(unknown_lints)] @@ -121,6 +118,7 @@ mod inner_connection; pub mod limits; #[cfg(feature = "load_extension")] mod load_extension_guard; +mod pragma; mod raw_statement; mod row; #[cfg(feature = "session")] @@ -242,6 +240,42 @@ fn str_to_cstring(s: &str) -> Result { Ok(CString::new(s)?) } +/// Returns `Ok((string ptr, len as c_int, SQLITE_STATIC | SQLITE_TRANSIENT))` +/// normally. +/// Returns errors if the string has embedded nuls or is too large for sqlite. +/// The `sqlite3_destructor_type` item is always `SQLITE_TRANSIENT` unless +/// the string was empty (in which case it's `SQLITE_STATIC`, and the ptr is +/// static). +fn str_for_sqlite(s: &str) -> Result<(*const c_char, c_int, ffi::sqlite3_destructor_type)> { + let len = len_as_c_int(s.len())?; + if memchr::memchr(0, s.as_bytes()).is_none() { + let (ptr, dtor_info) = if len != 0 { + (s.as_ptr() as *const c_char, ffi::SQLITE_TRANSIENT()) + } else { + // Return a pointer guaranteed to live forever + ("".as_ptr() as *const c_char, ffi::SQLITE_STATIC()) + }; + Ok((ptr, len, dtor_info)) + } else { + // There's an embedded nul, so we fabricate a NulError. + let e = CString::new(s); + Err(Error::NulError(e.unwrap_err())) + } +} + +// Helper to cast to c_int safely, returning the correct error type if the cast +// failed. +fn len_as_c_int(len: usize) -> Result { + if len >= (c_int::max_value() as usize) { + Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_TOOBIG), + None, + )) + } else { + Ok(len as c_int) + } +} + fn path_to_cstring(p: &Path) -> Result { let s = p.to_str().ok_or_else(|| Error::InvalidPath(p.to_owned()))?; str_to_cstring(s) @@ -478,7 +512,7 @@ impl Connection { where P: IntoIterator, P::Item: ToSql, - F: FnOnce(&Row<'_>) -> T, + F: FnOnce(&Row<'_>) -> Result, { let mut stmt = self.prepare(sql)?; stmt.query_row(params, f) @@ -500,12 +534,12 @@ impl Connection { /// or if the underlying SQLite call fails. pub fn query_row_named(&self, sql: &str, params: &[(&str, &dyn ToSql)], f: F) -> Result where - F: FnOnce(&Row<'_>) -> T, + F: FnOnce(&Row<'_>) -> Result, { let mut stmt = self.prepare(sql)?; let mut rows = stmt.query_named(params)?; - rows.get_expected_row().map(|r| f(&r)) + rows.get_expected_row().and_then(|r| f(&r)) } /// Convenience method to execute a query that is expected to return a @@ -521,7 +555,7 @@ impl Connection { /// conn.query_row_and_then( /// "SELECT value FROM preferences WHERE name='locale'", /// NO_PARAMS, - /// |row| row.get_checked(0), + /// |row| row.get(0), /// ) /// } /// ``` @@ -692,7 +726,7 @@ impl Connection { /// Return the number of rows modified, inserted or deleted by the most /// recently completed INSERT, UPDATE or DELETE statement on the database /// connection. - pub fn changes(&self) -> usize { + fn changes(&self) -> usize { self.db.borrow_mut().changes() } @@ -863,9 +897,9 @@ mod test { let tx2 = db2.transaction().unwrap(); // SELECT first makes sqlite lock with a shared lock - tx1.query_row("SELECT x FROM foo LIMIT 1", NO_PARAMS, |_| ()) + tx1.query_row("SELECT x FROM foo LIMIT 1", NO_PARAMS, |_| Ok(())) .unwrap(); - tx2.query_row("SELECT x FROM foo LIMIT 1", NO_PARAMS, |_| ()) + tx2.query_row("SELECT x FROM foo LIMIT 1", NO_PARAMS, |_| Ok(())) .unwrap(); tx1.execute("INSERT INTO foo VALUES(?1)", &[1]).unwrap(); @@ -1066,7 +1100,7 @@ mod test { let mut v = Vec::::new(); while let Some(row) = rows.next().unwrap() { - v.push(row.get(0)); + v.push(row.get(0).unwrap()); } assert_eq!(v, [3i32, 2, 1]); @@ -1077,7 +1111,7 @@ mod test { let mut v = Vec::::new(); while let Some(row) = rows.next().unwrap() { - v.push(row.get(0)); + v.push(row.get(0).unwrap()); } assert_eq!(v, [2i32, 1]); @@ -1130,7 +1164,7 @@ mod test { err => panic!("Unexpected error {}", err), } - let bad_query_result = db.query_row("NOT A PROPER QUERY; test123", NO_PARAMS, |_| ()); + let bad_query_result = db.query_row("NOT A PROPER QUERY; test123", NO_PARAMS, |_| Ok(())); assert!(bad_query_result.is_err()); } @@ -1404,7 +1438,7 @@ mod test { let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC").unwrap(); let results: Result> = query - .query_and_then(NO_PARAMS, |row| row.get_checked(1)) + .query_and_then(NO_PARAMS, |row| row.get(1)) .unwrap() .collect(); @@ -1425,7 +1459,7 @@ mod test { let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC").unwrap(); let bad_type: Result> = query - .query_and_then(NO_PARAMS, |row| row.get_checked(1)) + .query_and_then(NO_PARAMS, |row| row.get(1)) .unwrap() .collect(); @@ -1435,7 +1469,7 @@ mod test { } let bad_idx: Result> = query - .query_and_then(NO_PARAMS, |row| row.get_checked(3)) + .query_and_then(NO_PARAMS, |row| row.get(3)) .unwrap() .collect(); @@ -1459,9 +1493,7 @@ mod test { let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC").unwrap(); let results: CustomResult> = query - .query_and_then(NO_PARAMS, |row| { - row.get_checked(1).map_err(CustomError::Sqlite) - }) + .query_and_then(NO_PARAMS, |row| row.get(1).map_err(CustomError::Sqlite)) .unwrap() .collect(); @@ -1482,9 +1514,7 @@ mod test { let mut query = db.prepare("SELECT x, y FROM foo ORDER BY x DESC").unwrap(); let bad_type: CustomResult> = query - .query_and_then(NO_PARAMS, |row| { - row.get_checked(1).map_err(CustomError::Sqlite) - }) + .query_and_then(NO_PARAMS, |row| row.get(1).map_err(CustomError::Sqlite)) .unwrap() .collect(); @@ -1494,9 +1524,7 @@ mod test { } let bad_idx: CustomResult> = query - .query_and_then(NO_PARAMS, |row| { - row.get_checked(3).map_err(CustomError::Sqlite) - }) + .query_and_then(NO_PARAMS, |row| row.get(3).map_err(CustomError::Sqlite)) .unwrap() .collect(); @@ -1527,7 +1555,7 @@ mod test { let query = "SELECT x, y FROM foo ORDER BY x DESC"; let results: CustomResult = db.query_row_and_then(query, NO_PARAMS, |row| { - row.get_checked(1).map_err(CustomError::Sqlite) + row.get(1).map_err(CustomError::Sqlite) }); assert_eq!(results.unwrap(), "hello"); @@ -1544,7 +1572,7 @@ mod test { let query = "SELECT x, y FROM foo ORDER BY x DESC"; let bad_type: CustomResult = db.query_row_and_then(query, NO_PARAMS, |row| { - row.get_checked(1).map_err(CustomError::Sqlite) + row.get(1).map_err(CustomError::Sqlite) }); match bad_type.unwrap_err() { @@ -1553,7 +1581,7 @@ mod test { } let bad_idx: CustomResult = db.query_row_and_then(query, NO_PARAMS, |row| { - row.get_checked(3).map_err(CustomError::Sqlite) + row.get(3).map_err(CustomError::Sqlite) }); match bad_idx.unwrap_err() { @@ -1580,7 +1608,8 @@ mod test { db.execute_batch(sql).unwrap(); db.query_row("SELECT * FROM foo", params![], |r| { - assert_eq!(2, r.column_count()) + assert_eq!(2, r.column_count()); + Ok(()) }) .unwrap(); } diff --git a/src/pragma.rs b/src/pragma.rs new file mode 100644 index 0000000..75cdb11 --- /dev/null +++ b/src/pragma.rs @@ -0,0 +1,432 @@ +//! Pragma helpers + +use std::ops::Deref; + +use crate::error::Error; +use crate::ffi; +use crate::types::{ToSql, ToSqlOutput, ValueRef}; +use crate::{Connection, DatabaseName, Result, Row, NO_PARAMS}; + +pub struct Sql { + buf: String, +} + +impl Sql { + pub fn new() -> Sql { + Sql { buf: String::new() } + } + + pub fn push_pragma( + &mut self, + schema_name: Option>, + pragma_name: &str, + ) -> Result<()> { + self.push_keyword("PRAGMA")?; + self.push_space(); + if let Some(schema_name) = schema_name { + self.push_schema_name(schema_name); + self.push_dot(); + } + self.push_keyword(pragma_name) + } + + pub fn push_keyword(&mut self, keyword: &str) -> Result<()> { + if !keyword.is_empty() && is_identifier(keyword) { + self.buf.push_str(keyword); + Ok(()) + } else { + Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Invalid keyword \"{}\"", keyword)), + )) + } + } + + pub fn push_schema_name(&mut self, schema_name: DatabaseName<'_>) { + match schema_name { + DatabaseName::Main => self.buf.push_str("main"), + DatabaseName::Temp => self.buf.push_str("temp"), + DatabaseName::Attached(s) => self.push_identifier(s), + }; + } + + pub fn push_identifier(&mut self, s: &str) { + if is_identifier(s) { + self.buf.push_str(s); + } else { + self.wrap_and_escape(s, '"'); + } + } + + pub fn push_value(&mut self, value: &dyn ToSql) -> Result<()> { + let value = value.to_sql()?; + let value = match value { + ToSqlOutput::Borrowed(v) => v, + ToSqlOutput::Owned(ref v) => ValueRef::from(v), + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{:?}\"", value)), + )); + } + #[cfg(feature = "array")] + ToSqlOutput::Array(_) => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{:?}\"", value)), + )); + } + }; + match value { + ValueRef::Integer(i) => { + self.push_int(i); + } + ValueRef::Real(r) => { + self.push_real(r); + } + ValueRef::Text(s) => { + self.push_string_literal(s); + } + _ => { + return Err(Error::SqliteFailure( + ffi::Error::new(ffi::SQLITE_MISUSE), + Some(format!("Unsupported value \"{:?}\"", value)), + )); + } + }; + Ok(()) + } + + pub fn push_string_literal(&mut self, s: &str) { + self.wrap_and_escape(s, '\''); + } + + pub fn push_int(&mut self, i: i64) { + self.buf.push_str(&i.to_string()); + } + + pub fn push_real(&mut self, f: f64) { + self.buf.push_str(&f.to_string()); + } + + pub fn push_space(&mut self) { + self.buf.push(' '); + } + + pub fn push_dot(&mut self) { + self.buf.push('.'); + } + + pub fn push_equal_sign(&mut self) { + self.buf.push('='); + } + + pub fn open_brace(&mut self) { + self.buf.push('('); + } + + pub fn close_brace(&mut self) { + self.buf.push(')'); + } + + pub fn as_str(&self) -> &str { + &self.buf + } + + fn wrap_and_escape(&mut self, s: &str, quote: char) { + self.buf.push(quote); + let chars = s.chars(); + for ch in chars { + // escape `quote` by doubling it + if ch == quote { + self.buf.push(ch); + } + self.buf.push(ch) + } + self.buf.push(quote); + } +} + +impl Deref for Sql { + type Target = str; + + fn deref(&self) -> &str { + self.as_str() + } +} + +impl Connection { + /// Query the current value of `pragma_name`. + /// + /// Some pragmas will return multiple rows/values which cannot be retrieved + /// with this method. + /// + /// Prefer [PRAGMA function](https://sqlite.org/pragma.html#pragfunc) introduced in SQLite 3.20: + /// `SELECT user_version FROM pragma_user_version;` + pub fn pragma_query_value( + &self, + schema_name: Option>, + pragma_name: &str, + f: F, + ) -> Result + where + F: FnOnce(&Row<'_>) -> Result, + { + let mut query = Sql::new(); + query.push_pragma(schema_name, pragma_name)?; + self.query_row(&query, NO_PARAMS, f) + } + + /// Query the current rows/values of `pragma_name`. + /// + /// Prefer [PRAGMA function](https://sqlite.org/pragma.html#pragfunc) introduced in SQLite 3.20: + /// `SELECT * FROM pragma_collation_list;` + pub fn pragma_query( + &self, + schema_name: Option>, + pragma_name: &str, + mut f: F, + ) -> Result<()> + where + F: FnMut(&Row<'_>) -> Result<()>, + { + let mut query = Sql::new(); + query.push_pragma(schema_name, pragma_name)?; + let mut stmt = self.prepare(&query)?; + let mut rows = stmt.query(NO_PARAMS)?; + while let Some(result_row) = rows.next()? { + let row = result_row; + f(&row)?; + } + Ok(()) + } + + /// Query the current value(s) of `pragma_name` associated to + /// `pragma_value`. + /// + /// This method can be used with query-only pragmas which need an argument + /// (e.g. `table_info('one_tbl')`) or pragmas which returns value(s) + /// (e.g. `integrity_check`). + /// + /// Prefer [PRAGMA function](https://sqlite.org/pragma.html#pragfunc) introduced in SQLite 3.20: + /// `SELECT * FROM pragma_table_info(?);` + pub fn pragma( + &self, + schema_name: Option>, + pragma_name: &str, + pragma_value: &dyn ToSql, + mut f: F, + ) -> Result<()> + where + F: FnMut(&Row<'_>) -> Result<()>, + { + let mut sql = Sql::new(); + sql.push_pragma(schema_name, pragma_name)?; + // The argument may be either in parentheses + // or it may be separated from the pragma name by an equal sign. + // The two syntaxes yield identical results. + sql.open_brace(); + sql.push_value(pragma_value)?; + sql.close_brace(); + let mut stmt = self.prepare(&sql)?; + let mut rows = stmt.query(NO_PARAMS)?; + while let Some(result_row) = rows.next()? { + let row = result_row; + f(&row)?; + } + Ok(()) + } + + /// Set a new value to `pragma_name`. + /// + /// Some pragmas will return the updated value which cannot be retrieved + /// with this method. + pub fn pragma_update( + &self, + schema_name: Option>, + pragma_name: &str, + pragma_value: &dyn ToSql, + ) -> Result<()> { + let mut sql = Sql::new(); + sql.push_pragma(schema_name, pragma_name)?; + // The argument may be either in parentheses + // or it may be separated from the pragma name by an equal sign. + // The two syntaxes yield identical results. + sql.push_equal_sign(); + sql.push_value(pragma_value)?; + self.execute_batch(&sql) + } + + /// Set a new value to `pragma_name` and return the updated value. + /// + /// Only few pragmas automatically return the updated value. + pub fn pragma_update_and_check( + &self, + schema_name: Option>, + pragma_name: &str, + pragma_value: &dyn ToSql, + f: F, + ) -> Result + where + F: FnOnce(&Row<'_>) -> Result, + { + let mut sql = Sql::new(); + sql.push_pragma(schema_name, pragma_name)?; + // The argument may be either in parentheses + // or it may be separated from the pragma name by an equal sign. + // The two syntaxes yield identical results. + sql.push_equal_sign(); + sql.push_value(pragma_value)?; + self.query_row(&sql, NO_PARAMS, f) + } +} + +fn is_identifier(s: &str) -> bool { + let chars = s.char_indices(); + for (i, ch) in chars { + if i == 0 { + if !is_identifier_start(ch) { + return false; + } + } else if !is_identifier_continue(ch) { + return false; + } + } + true +} + +fn is_identifier_start(c: char) -> bool { + (c >= 'A' && c <= 'Z') || c == '_' || (c >= 'a' && c <= 'z') || c > '\x7F' +} + +fn is_identifier_continue(c: char) -> bool { + c == '$' + || (c >= '0' && c <= '9') + || (c >= 'A' && c <= 'Z') + || c == '_' + || (c >= 'a' && c <= 'z') + || c > '\x7F' +} + +#[cfg(test)] +mod test { + use super::Sql; + use crate::pragma; + use crate::{Connection, DatabaseName}; + + #[test] + fn pragma_query_value() { + let db = Connection::open_in_memory().unwrap(); + let user_version: i32 = db + .pragma_query_value(None, "user_version", |row| row.get(0)) + .unwrap(); + assert_eq!(0, user_version); + } + + #[test] + #[cfg(feature = "bundled")] + fn pragma_func_query_value() { + use crate::NO_PARAMS; + + let db = Connection::open_in_memory().unwrap(); + let user_version: i32 = db + .query_row( + "SELECT user_version FROM pragma_user_version", + NO_PARAMS, + |row| row.get(0), + ) + .unwrap(); + assert_eq!(0, user_version); + } + + #[test] + fn pragma_query_no_schema() { + let db = Connection::open_in_memory().unwrap(); + let mut user_version = -1; + db.pragma_query(None, "user_version", |row| { + user_version = row.get(0)?; + Ok(()) + }) + .unwrap(); + assert_eq!(0, user_version); + } + + #[test] + fn pragma_query_with_schema() { + let db = Connection::open_in_memory().unwrap(); + let mut user_version = -1; + db.pragma_query(Some(DatabaseName::Main), "user_version", |row| { + user_version = row.get(0)?; + Ok(()) + }) + .unwrap(); + assert_eq!(0, user_version); + } + + #[test] + fn pragma() { + let db = Connection::open_in_memory().unwrap(); + let mut columns = Vec::new(); + db.pragma(None, "table_info", &"sqlite_master", |row| { + let column: String = row.get(1)?; + columns.push(column); + Ok(()) + }) + .unwrap(); + assert_eq!(5, columns.len()); + } + + #[test] + #[cfg(feature = "bundled")] + fn pragma_func() { + let db = Connection::open_in_memory().unwrap(); + let mut table_info = db.prepare("SELECT * FROM pragma_table_info(?)").unwrap(); + let mut columns = Vec::new(); + let mut rows = table_info.query(&["sqlite_master"]).unwrap(); + + while let Some(row) = rows.next().unwrap() { + let row = row; + let column: String = row.get(1).unwrap(); + columns.push(column); + } + assert_eq!(5, columns.len()); + } + + #[test] + fn pragma_update() { + let db = Connection::open_in_memory().unwrap(); + db.pragma_update(None, "user_version", &1).unwrap(); + } + + #[test] + fn pragma_update_and_check() { + let db = Connection::open_in_memory().unwrap(); + let journal_mode: String = db + .pragma_update_and_check(None, "journal_mode", &"OFF", |row| row.get(0)) + .unwrap(); + assert_eq!("off", &journal_mode); + } + + #[test] + fn is_identifier() { + assert!(pragma::is_identifier("full")); + assert!(pragma::is_identifier("r2d2")); + assert!(!pragma::is_identifier("sp ce")); + assert!(!pragma::is_identifier("semi;colon")); + } + + #[test] + fn double_quote() { + let mut sql = Sql::new(); + sql.push_schema_name(DatabaseName::Attached(r#"schema";--"#)); + assert_eq!(r#""schema"";--""#, sql.as_str()); + } + + #[test] + fn wrap_and_escape() { + let mut sql = Sql::new(); + sql.push_string_literal("value'; --"); + assert_eq!("'value''; --'", sql.as_str()); + } +} diff --git a/src/row.rs b/src/row.rs index 180e37d..72f8a7c 100644 --- a/src/row.rs +++ b/src/row.rs @@ -15,6 +15,23 @@ impl<'stmt> Rows<'stmt> { stmt.reset(); } } + + /// Attempt to get the next row from the query. Returns `Ok(Some(Row))` if + /// there is another row, `Err(...)` if there was an error + /// getting the next row, and `Ok(None)` if all rows have been retrieved. + /// + /// ## Note + /// + /// This interface is not compatible with Rust's `Iterator` trait, because + /// the lifetime of the returned row is tied to the lifetime of `self`. + /// This is a fallible "streaming iterator". For a more natural interface, + /// consider using `query_map` or `query_and_then` instead, which + /// return types that implement `Iterator`. + #[allow(clippy::should_implement_trait)] // cannot implement Iterator + pub fn next(&mut self) -> Result>> { + self.advance()?; + Ok((*self).get()) + } } impl<'stmt> Rows<'stmt> { @@ -47,7 +64,7 @@ pub struct MappedRows<'stmt, F> { impl<'stmt, T, F> MappedRows<'stmt, F> where - F: FnMut(&Row<'_>) -> T, + F: FnMut(&Row<'_>) -> Result, { pub(crate) fn new(rows: Rows<'stmt>, f: F) -> MappedRows<'stmt, F> { MappedRows { rows, map: f } @@ -56,7 +73,7 @@ where impl Iterator for MappedRows<'_, F> where - F: FnMut(&Row<'_>) -> T, + F: FnMut(&Row<'_>) -> Result, { type Item = Result; @@ -65,7 +82,7 @@ where self.rows .next() .transpose() - .map(|row_result| row_result.map(|row| (map)(&row))) + .map(|row_result| row_result.and_then(|row| (map)(&row))) } } @@ -145,16 +162,16 @@ impl<'stmt> Row<'stmt> { /// /// ## Failure /// - /// Panics if calling `row.get_checked(idx)` would return an error, + /// Panics if calling `row.get(idx)` would return an error, /// including: /// - /// * If the underlying SQLite column type is not a valid type as a - /// source for `T` + /// * If the underlying SQLite column type is not a valid type as a source + /// for `T` /// * If the underlying SQLite integral value is outside the range /// representable by `T` /// * If `idx` is outside the range of columns in the returned query - pub fn get(&self, idx: I) -> T { - self.get_checked(idx).unwrap() + pub fn get_unwrap(&self, idx: I) -> T { + self.get(idx).unwrap() } /// Get the value of a particular column of the result row. @@ -173,7 +190,7 @@ impl<'stmt> Row<'stmt> { /// If the result type is i128 (which requires the `i128_blob` feature to be /// enabled), and the underlying SQLite column is a blob whose size is not /// 16 bytes, `Error::InvalidColumnType` will also be returned. - pub fn get_checked(&self, idx: I) -> Result { + pub fn get(&self, idx: I) -> Result { let idx = idx.idx(self.stmt)?; let value = self.stmt.value_ref(idx); FromSql::column_result(value).map_err(|err| match err { @@ -193,7 +210,7 @@ impl<'stmt> Row<'stmt> { /// This `ValueRef` is valid only as long as this Row, which is enforced by /// it's lifetime. This means that while this method is completely safe, /// it can be somewhat difficult to use, and most callers will be better - /// served by `get` or `get_checked`. + /// served by `get` or `get`. /// /// ## Failure /// @@ -217,7 +234,7 @@ impl<'stmt> Row<'stmt> { /// This `ValueRef` is valid only as long as this Row, which is enforced by /// it's lifetime. This means that while this method is completely safe, /// it can be difficult to use, and most callers will be better served by - /// `get` or `get_checked`. + /// `get` or `get`. /// /// ## Failure /// diff --git a/src/statement.rs b/src/statement.rs index 70bc9a8..850b6d0 100644 --- a/src/statement.rs +++ b/src/statement.rs @@ -7,10 +7,9 @@ use std::slice::from_raw_parts; use std::{convert, fmt, mem, ptr, result, str}; use super::ffi; -use super::str_to_cstring; +use super::{len_as_c_int, str_for_sqlite, str_to_cstring}; use super::{ - AndThenRows, Connection, Error, FallibleStreamingIterator, MappedRows, RawStatement, Result, - Row, Rows, ValueRef, + AndThenRows, Connection, Error, MappedRows, RawStatement, Result, Row, Rows, ValueRef, }; use crate::types::{ToSql, ToSqlOutput}; #[cfg(feature = "array")] @@ -176,7 +175,7 @@ impl Statement<'_> { /// /// let mut names = Vec::new(); /// while let Some(row) = rows.next()? { - /// names.push(row.get(0)); + /// names.push(row.get(0)?); /// } /// /// Ok(names) @@ -267,7 +266,7 @@ impl Statement<'_> { where P: IntoIterator, P::Item: ToSql, - F: FnMut(&Row<'_>) -> T, + F: FnMut(&Row<'_>) -> Result, { let rows = self.query(params)?; Ok(MappedRows::new(rows, f)) @@ -306,7 +305,7 @@ impl Statement<'_> { f: F, ) -> Result> where - F: FnMut(&Row<'_>) -> T, + F: FnMut(&Row<'_>) -> Result, { let rows = self.query_named(params)?; Ok(MappedRows::new(rows, f)) @@ -354,7 +353,7 @@ impl Statement<'_> { /// fn get_names(conn: &Connection) -> Result> { /// let mut stmt = conn.prepare("SELECT name FROM people WHERE id = :id")?; /// let rows = - /// stmt.query_and_then_named(&[(":id", &"one")], |row| name_to_person(row.get(0)))?; + /// stmt.query_and_then_named(&[(":id", &"one")], |row| name_to_person(row.get(0)?))?; /// /// let mut persons = Vec::new(); /// for person_result in rows { @@ -410,11 +409,11 @@ impl Statement<'_> { where P: IntoIterator, P::Item: ToSql, - F: FnOnce(&Row<'_>) -> T, + F: FnOnce(&Row<'_>) -> Result, { let mut rows = self.query(params)?; - rows.get_expected_row().map(|r| f(&r)) + rows.get_expected_row().and_then(|r| f(&r)) } /// Consumes the statement. @@ -506,37 +505,19 @@ impl Statement<'_> { ValueRef::Integer(i) => unsafe { ffi::sqlite3_bind_int64(ptr, col as c_int, i) }, ValueRef::Real(r) => unsafe { ffi::sqlite3_bind_double(ptr, col as c_int, r) }, ValueRef::Text(s) => unsafe { - let length = s.len(); - if length > ::std::i32::MAX as usize { - ffi::SQLITE_TOOBIG - } else { - let c_str = str_to_cstring(s)?; - let destructor = if length > 0 { - ffi::SQLITE_TRANSIENT() - } else { - ffi::SQLITE_STATIC() - }; - ffi::sqlite3_bind_text( - ptr, - col as c_int, - c_str.as_ptr(), - length as c_int, - destructor, - ) - } + let (c_str, len, destructor) = str_for_sqlite(s)?; + ffi::sqlite3_bind_text(ptr, col as c_int, c_str, len, destructor) }, ValueRef::Blob(b) => unsafe { - let length = b.len(); - if length > ::std::i32::MAX as usize { - ffi::SQLITE_TOOBIG - } else if length == 0 { + let length = len_as_c_int(b.len())?; + if length == 0 { ffi::sqlite3_bind_zeroblob(ptr, col as c_int, 0) } else { ffi::sqlite3_bind_blob( ptr, col as c_int, b.as_ptr() as *const c_void, - length as c_int, + length, ffi::SQLITE_TRANSIENT(), ) } @@ -732,7 +713,7 @@ pub enum StatementStatus { #[cfg(test)] mod test { use crate::types::ToSql; - use crate::{Connection, Error, FallibleStreamingIterator, Result, NO_PARAMS}; + use crate::{Connection, Error, Result, NO_PARAMS}; #[test] fn test_execute_named() { @@ -798,8 +779,8 @@ mod test { .unwrap(); let mut rows = stmt.query_named(&[(":name", &"one")]).unwrap(); - let id: i32 = rows.next().unwrap().unwrap().get(0); - assert_eq!(1, id); + let id: Result = rows.next().unwrap().unwrap().get(0); + assert_eq!(Ok(1), id); } #[test] @@ -815,12 +796,14 @@ mod test { .prepare("SELECT id FROM test where name = :name") .unwrap(); let mut rows = stmt - .query_named(&[(":name", &"one")]).unwrap().map(|row| { - let id: i32 = row.get(0); - 2 * id - }); + .query_map_named(&[(":name", &"one")], |row| { + let id: Result = row.get(0); + id.map(|i| 2 * i) + }) + .unwrap(); - assert_eq!(Ok(Some(&2)), rows.next()); + let doubled_id: i32 = rows.next().unwrap().unwrap(); + assert_eq!(2, doubled_id); } #[test] @@ -838,7 +821,7 @@ mod test { .unwrap(); let mut rows = stmt .query_and_then_named(&[(":name", &"one")], |row| { - let id: i32 = row.get(0); + let id: i32 = row.get(0)?; if id == 1 { Ok(id) } else { diff --git a/src/trace.rs b/src/trace.rs index ddf537f..8af3f27 100644 --- a/src/trace.rs +++ b/src/trace.rs @@ -140,13 +140,13 @@ mod test { let mut db = Connection::open_in_memory().unwrap(); db.trace(Some(tracer)); { - let _ = db.query_row("SELECT ?", &[1i32], |_| {}); - let _ = db.query_row("SELECT ?", &["hello"], |_| {}); + let _ = db.query_row("SELECT ?", &[1i32], |_| Ok(())); + let _ = db.query_row("SELECT ?", &["hello"], |_| Ok(())); } db.trace(None); { - let _ = db.query_row("SELECT ?", &[2i32], |_| {}); - let _ = db.query_row("SELECT ?", &["goodbye"], |_| {}); + let _ = db.query_row("SELECT ?", &[2i32], |_| Ok(())); + let _ = db.query_row("SELECT ?", &["goodbye"], |_| Ok(())); } let traced_stmts = TRACED_STMTS.lock().unwrap(); diff --git a/src/types/from_sql.rs b/src/types/from_sql.rs index c917b7a..ce679aa 100644 --- a/src/types/from_sql.rs +++ b/src/types/from_sql.rs @@ -151,7 +151,7 @@ impl FromSql for bool { impl FromSql for String { fn column_result(value: ValueRef<'_>) -> FromSqlResult { - value.as_str().map(|s| s.to_string()) + value.as_str().map(ToString::to_string) } } @@ -210,8 +210,7 @@ mod test { { for n in out_of_range { let err = db - .query_row("SELECT ?", &[n], |r| r.get_checked::<_, T>(0)) - .unwrap() + .query_row("SELECT ?", &[n], |r| r.get::<_, T>(0)) .unwrap_err(); match err { Error::IntegralValueOutOfRange(_, value) => assert_eq!(*n, value), diff --git a/src/types/mod.rs b/src/types/mod.rs index c704caa..48aa60a 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -110,7 +110,7 @@ mod test { use time; use super::Value; - use crate::{Connection, Error, FallibleStreamingIterator, NO_PARAMS}; + use crate::{Connection, Error, NO_PARAMS}; use std::f64::EPSILON; use std::os::raw::{c_double, c_int}; @@ -207,16 +207,16 @@ mod test { { let row1 = rows.next().unwrap().unwrap(); - let s1: Option = row1.get(0); - let b1: Option> = row1.get(1); + let s1: Option = row1.get_unwrap(0); + let b1: Option> = row1.get_unwrap(1); assert_eq!(s.unwrap(), s1.unwrap()); assert!(b1.is_none()); } { let row2 = rows.next().unwrap().unwrap(); - let s2: Option = row2.get(0); - let b2: Option> = row2.get(1); + let s2: Option = row2.get_unwrap(0); + let b2: Option> = row2.get_unwrap(1); assert!(s2.is_none()); assert_eq!(b, b2); } @@ -246,102 +246,94 @@ mod test { let row = rows.next().unwrap().unwrap(); // check the correct types come back as expected - assert_eq!(vec![1, 2], row.get_checked::<_, Vec>(0).unwrap()); - assert_eq!("text", row.get_checked::<_, String>(1).unwrap()); - assert_eq!(1, row.get_checked::<_, c_int>(2).unwrap()); - assert!((1.5 - row.get_checked::<_, c_double>(3).unwrap()).abs() < EPSILON); - assert!(row.get_checked::<_, Option>(4).unwrap().is_none()); - assert!(row.get_checked::<_, Option>(4).unwrap().is_none()); - assert!(row.get_checked::<_, Option>(4).unwrap().is_none()); + assert_eq!(vec![1, 2], row.get::<_, Vec>(0).unwrap()); + assert_eq!("text", row.get::<_, String>(1).unwrap()); + assert_eq!(1, row.get::<_, c_int>(2).unwrap()); + assert!((1.5 - row.get::<_, c_double>(3).unwrap()).abs() < EPSILON); + assert!(row.get::<_, Option>(4).unwrap().is_none()); + assert!(row.get::<_, Option>(4).unwrap().is_none()); + assert!(row.get::<_, Option>(4).unwrap().is_none()); // check some invalid types // 0 is actually a blob (Vec) assert!(is_invalid_column_type( - row.get_checked::<_, c_int>(0).err().unwrap() + row.get::<_, c_int>(0).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, c_int>(0).err().unwrap() + row.get::<_, c_int>(0).err().unwrap() + )); + assert!(is_invalid_column_type(row.get::<_, i64>(0).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, c_double>(0).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, i64>(0).err().unwrap() + row.get::<_, String>(0).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, c_double>(0).err().unwrap() + row.get::<_, time::Timespec>(0).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, String>(0).err().unwrap() - )); - assert!(is_invalid_column_type( - row.get_checked::<_, time::Timespec>(0).err().unwrap() - )); - assert!(is_invalid_column_type( - row.get_checked::<_, Option>(0).err().unwrap() + row.get::<_, Option>(0).err().unwrap() )); // 1 is actually a text (String) assert!(is_invalid_column_type( - row.get_checked::<_, c_int>(1).err().unwrap() + row.get::<_, c_int>(1).err().unwrap() + )); + assert!(is_invalid_column_type(row.get::<_, i64>(1).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, c_double>(1).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, i64>(1).err().unwrap() + row.get::<_, Vec>(1).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, c_double>(1).err().unwrap() - )); - assert!(is_invalid_column_type( - row.get_checked::<_, Vec>(1).err().unwrap() - )); - assert!(is_invalid_column_type( - row.get_checked::<_, Option>(1).err().unwrap() + row.get::<_, Option>(1).err().unwrap() )); // 2 is actually an integer assert!(is_invalid_column_type( - row.get_checked::<_, String>(2).err().unwrap() + row.get::<_, String>(2).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, Vec>(2).err().unwrap() + row.get::<_, Vec>(2).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, Option>(2).err().unwrap() + row.get::<_, Option>(2).err().unwrap() )); // 3 is actually a float (c_double) assert!(is_invalid_column_type( - row.get_checked::<_, c_int>(3).err().unwrap() + row.get::<_, c_int>(3).err().unwrap() + )); + assert!(is_invalid_column_type(row.get::<_, i64>(3).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, String>(3).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, i64>(3).err().unwrap() + row.get::<_, Vec>(3).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, String>(3).err().unwrap() - )); - assert!(is_invalid_column_type( - row.get_checked::<_, Vec>(3).err().unwrap() - )); - assert!(is_invalid_column_type( - row.get_checked::<_, Option>(3).err().unwrap() + row.get::<_, Option>(3).err().unwrap() )); // 4 is actually NULL assert!(is_invalid_column_type( - row.get_checked::<_, c_int>(4).err().unwrap() + row.get::<_, c_int>(4).err().unwrap() + )); + assert!(is_invalid_column_type(row.get::<_, i64>(4).err().unwrap())); + assert!(is_invalid_column_type( + row.get::<_, c_double>(4).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, i64>(4).err().unwrap() + row.get::<_, String>(4).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, c_double>(4).err().unwrap() + row.get::<_, Vec>(4).err().unwrap() )); assert!(is_invalid_column_type( - row.get_checked::<_, String>(4).err().unwrap() - )); - assert!(is_invalid_column_type( - row.get_checked::<_, Vec>(4).err().unwrap() - )); - assert!(is_invalid_column_type( - row.get_checked::<_, time::Timespec>(4).err().unwrap() + row.get::<_, time::Timespec>(4).err().unwrap() )); } @@ -360,19 +352,16 @@ mod test { let mut rows = stmt.query(NO_PARAMS).unwrap(); let row = rows.next().unwrap().unwrap(); - assert_eq!( - Value::Blob(vec![1, 2]), - row.get_checked::<_, Value>(0).unwrap() - ); + assert_eq!(Value::Blob(vec![1, 2]), row.get::<_, Value>(0).unwrap()); assert_eq!( Value::Text(String::from("text")), - row.get_checked::<_, Value>(1).unwrap() + row.get::<_, Value>(1).unwrap() ); - assert_eq!(Value::Integer(1), row.get_checked::<_, Value>(2).unwrap()); - match row.get_checked::<_, Value>(3).unwrap() { + assert_eq!(Value::Integer(1), row.get::<_, Value>(2).unwrap()); + match row.get::<_, Value>(3).unwrap() { Value::Real(val) => assert!((1.5 - val).abs() < EPSILON), x => panic!("Invalid Value {:?}", x), } - assert_eq!(Value::Null, row.get_checked::<_, Value>(4).unwrap()); + assert_eq!(Value::Null, row.get::<_, Value>(4).unwrap()); } } diff --git a/src/types/to_sql.rs b/src/types/to_sql.rs index 0d9a21d..ed5da10 100644 --- a/src/types/to_sql.rs +++ b/src/types/to_sql.rs @@ -229,7 +229,7 @@ mod test { let res = stmt .query_map(NO_PARAMS, |row| { - (row.get::<_, i128>(0), row.get::<_, String>(1)) + Ok((row.get::<_, i128>(0)?, row.get::<_, String>(1)?)) }) .unwrap() .collect::, _>>() diff --git a/src/vtab/csvtab.rs b/src/vtab/csvtab.rs index 173a4c2..8e733f4 100644 --- a/src/vtab/csvtab.rs +++ b/src/vtab/csvtab.rs @@ -346,7 +346,7 @@ impl From for Error { #[cfg(test)] mod test { use crate::vtab::csvtab; - use crate::{Connection, FallibleStreamingIterator, Result, NO_PARAMS}; + use crate::{Connection, Result, NO_PARAMS}; #[test] fn test_csv_module() { @@ -389,7 +389,7 @@ mod test { let mut rows = s.query(NO_PARAMS).unwrap(); let row = rows.next().unwrap().unwrap(); - assert_eq!(row.get::<_, i32>(0), 2); + assert_eq!(row.get_unwrap::<_, i32>(0), 2); } db.execute_batch("DROP TABLE vtab").unwrap(); }