diff --git a/.travis.yml b/.travis.yml index c96f8af..d520dab 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,10 +1,6 @@ language: rust sudo: false -env: - global: - secure: "FyGzHF0AIYdBcuM/2qIoABotx3MbNAlaHDzxPbbeUlVg64bnuib9G9K/qWve0a1BWCgv+8e/SbXZb7gt3JlUNE27aE4RZG4FEdtEpLYQp87Dc9d9HX0FwpUeFK3binsrtYl4WEBnIjQ3ICnUVey0E6GHEdkM+t5bWyJO5c4dJ30=" - script: - cargo build - cargo test diff --git a/Cargo.toml b/Cargo.toml index ccabe9d..3babbc3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ trace = [] [dependencies] time = "~0.1.0" bitflags = "0.7" +lru-cache = "0.0.7" libc = "~0.2" clippy = {version = "~0.0.58", optional = true} chrono = { version = "~0.2", optional = true } diff --git a/Changelog.md b/Changelog.md index c5007f3..df401e1 100644 --- a/Changelog.md +++ b/Changelog.md @@ -7,11 +7,20 @@ hazard related to the lack of these lifetime connections. We were already recommending the use of `query_map` and `query_and_then` over raw `query`; both of theose still return handles that implement `Iterator`. +* BREAKING CHANGE: `Transaction::savepoint()` now returns a `Savepoint` instead of another + `Transaction`. Unlike `Transaction`, `Savepoint`s can be rolled back while keeping the current + savepoint active. * BREAKING CHANGE: Creating transactions from a `Connection` or savepoints from a `Transaction` now take `&mut self` instead of `&self` to correctly represent that transactions within a connection are inherently nested. While a transaction is alive, the parent connection or transaction is unusable, so `Transaction` now implements `Deref`, giving access to `Connection`'s methods via the `Transaction` itself. +* BREAKING CHANGE: `Transaction::set_commit` and `Transaction::set_rollback` have been replaced + by `Transaction::set_drop_behavior`. +* Adds `Connection::prepare_cached`. `Connection` now keeps an internal cache of any statements + prepared via this method. The size of this cache defaults to 16 (`prepare_cached` will always + work but may re-prepare statements if more are prepared than the cache holds), and can be + controlled via `Connection::set_prepared_statement_cache_capacity`. * Adds `query_map_named` and `query_and_then_named` to `Statement`. * Adds `insert` convenience method to `Statement` which returns the row ID of an inserted row. * Adds `exists` convenience method returning whether a query finds one or more rows. diff --git a/benches/lib.rs b/benches/lib.rs new file mode 100644 index 0000000..92fddef --- /dev/null +++ b/benches/lib.rs @@ -0,0 +1,23 @@ +#![feature(test)] +extern crate test; + +extern crate rusqlite; + +use rusqlite::Connection; +use rusqlite::cache::StatementCache; +use test::Bencher; + +#[bench] +fn bench_no_cache(b: &mut Bencher) { + let db = Connection::open_in_memory().unwrap(); + let sql = "SELECT 1, 'test', 3.14 UNION SELECT 2, 'exp', 2.71"; + b.iter(|| db.prepare(sql).unwrap()); +} + +#[bench] +fn bench_cache(b: &mut Bencher) { + let db = Connection::open_in_memory().unwrap(); + let cache = StatementCache::new(&db, 15); + let sql = "SELECT 1, 'test', 3.14 UNION SELECT 2, 'exp', 2.71"; + b.iter(|| cache.get(sql).unwrap()); +} diff --git a/publish-ghp-docs.sh b/publish-ghp-docs.sh index f9eeb17..c8fc790 100755 --- a/publish-ghp-docs.sh +++ b/publish-ghp-docs.sh @@ -8,7 +8,7 @@ fi cd $(git rev-parse --show-toplevel) rm -rf target/doc/ -multirust run nightly cargo doc --no-deps --features "load_extension trace" +multirust run nightly cargo doc --no-deps --features "backup cache functions load_extension trace blob" echo '' > target/doc/index.html ghp-import target/doc git push origin gh-pages:gh-pages diff --git a/src/cache.rs b/src/cache.rs new file mode 100644 index 0000000..f5fade3 --- /dev/null +++ b/src/cache.rs @@ -0,0 +1,269 @@ +//! Prepared statements cache for faster execution. + +use std::cell::RefCell; +use std::ops::{Deref, DerefMut}; +use lru_cache::LruCache; +use {Result, Connection, Statement}; +use raw_statement::RawStatement; + +impl Connection { + /// Prepare a SQL statement for execution, returning a previously prepared (but + /// not currently in-use) statement if one is available. The returned statement + /// will be cached for reuse by future calls to `prepare_cached` once it is + /// dropped. + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// fn insert_new_people(conn: &Connection) -> Result<()> { + /// { + /// let mut stmt = try!(conn.prepare_cached("INSERT INTO People (name) VALUES (?)")); + /// try!(stmt.execute(&[&"Joe Smith"])); + /// } + /// { + /// // This will return the same underlying SQLite statement handle without + /// // having to prepare it again. + /// let mut stmt = try!(conn.prepare_cached("INSERT INTO People (name) VALUES (?)")); + /// try!(stmt.execute(&[&"Bob Jones"])); + /// } + /// Ok(()) + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if `sql` cannot be converted to a C-compatible string or if the + /// underlying SQLite call fails. + pub fn prepare_cached<'a>(&'a self, sql: &str) -> Result> { + self.cache.get(&self, sql) + } + + /// Set the maximum number of cached prepared statements this connection will hold. + /// By default, a connection will hold a relatively small number of cached statements. + /// If you need more, or know that you will not use cached statements, you can set + /// the capacity manually using this method. + pub fn set_prepared_statement_cache_capacity(&self, capacity: usize) { + self.cache.set_capacity(capacity) + } +} + +/// Prepared statements LRU cache. +#[derive(Debug)] +pub struct StatementCache(RefCell>); + +/// Cacheable statement. +/// +/// Statement will return automatically to the cache by default. +/// If you want the statement to be discarded, call `discard()` on it. +pub struct CachedStatement<'conn> { + stmt: Option>, + cache: &'conn StatementCache, +} + +impl<'conn> Deref for CachedStatement<'conn> { + type Target = Statement<'conn>; + + fn deref(&self) -> &Statement<'conn> { + self.stmt.as_ref().unwrap() + } +} + +impl<'conn> DerefMut for CachedStatement<'conn> { + fn deref_mut(&mut self) -> &mut Statement<'conn> { + self.stmt.as_mut().unwrap() + } +} + +impl<'conn> Drop for CachedStatement<'conn> { + #[allow(unused_must_use)] + fn drop(&mut self) { + if let Some(stmt) = self.stmt.take() { + self.cache.cache_stmt(stmt.into()); + } + } +} + +impl<'conn> CachedStatement<'conn> { + fn new(stmt: Statement<'conn>, cache: &'conn StatementCache) -> CachedStatement<'conn> { + CachedStatement { + stmt: Some(stmt), + cache: cache, + } + } + + pub fn discard(mut self) { + self.stmt = None; + } +} + +impl StatementCache { + /// Create a statement cache. + pub fn with_capacity(capacity: usize) -> StatementCache { + StatementCache(RefCell::new(LruCache::new(capacity))) + } + + fn set_capacity(&self, capacity: usize) { + self.0.borrow_mut().set_capacity(capacity) + } + + // Search the cache for a prepared-statement object that implements `sql`. + // If no such prepared-statement can be found, allocate and prepare a new one. + // + // # Failure + // + // Will return `Err` if no cached statement can be found and the underlying SQLite prepare + // call fails. + fn get<'conn>(&'conn self, + conn: &'conn Connection, + sql: &str) + -> Result> { + let mut cache = self.0.borrow_mut(); + let stmt = match cache.remove(sql) { + Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)), + None => conn.prepare(sql), + }; + stmt.map(|stmt| CachedStatement::new(stmt, self)) + } + + // Return a statement to the cache. + fn cache_stmt(&self, stmt: RawStatement) { + let mut cache = self.0.borrow_mut(); + stmt.clear_bindings(); + let sql = String::from_utf8_lossy(stmt.sql().to_bytes()).to_string(); + cache.insert(sql, stmt); + } +} + +#[cfg(test)] +mod test { + use Connection; + use super::StatementCache; + + impl StatementCache { + fn clear(&self) { + self.0.borrow_mut().clear(); + } + + fn len(&self) -> usize { + self.0.borrow().len() + } + + fn capacity(&self) -> usize { + self.0.borrow().capacity() + } + } + + #[test] + fn test_cache() { + let db = Connection::open_in_memory().unwrap(); + let cache = &db.cache; + let initial_capacity = cache.capacity(); + assert_eq!(0, cache.len()); + assert!(initial_capacity > 0); + + let sql = "PRAGMA schema_version"; + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!(0, + stmt.query(&[]).unwrap().get_expected_row().unwrap().get::(0)); + } + assert_eq!(1, cache.len()); + + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!(0, + stmt.query(&[]).unwrap().get_expected_row().unwrap().get::(0)); + } + assert_eq!(1, cache.len()); + + cache.clear(); + assert_eq!(0, cache.len()); + assert_eq!(initial_capacity, cache.capacity()); + } + + #[test] + fn test_set_capacity() { + let db = Connection::open_in_memory().unwrap(); + let cache = &db.cache; + + let sql = "PRAGMA schema_version"; + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!(0, + stmt.query(&[]).unwrap().get_expected_row().unwrap().get::(0)); + } + assert_eq!(1, cache.len()); + + db.set_prepared_statement_cache_capacity(0); + assert_eq!(0, cache.len()); + + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!(0, + stmt.query(&[]).unwrap().get_expected_row().unwrap().get::(0)); + } + assert_eq!(0, cache.len()); + + db.set_prepared_statement_cache_capacity(8); + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!(0, + stmt.query(&[]).unwrap().get_expected_row().unwrap().get::(0)); + } + assert_eq!(1, cache.len()); + } + + #[test] + fn test_discard() { + let db = Connection::open_in_memory().unwrap(); + let cache = &db.cache; + + let sql = "PRAGMA schema_version"; + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(0, cache.len()); + assert_eq!(0, + stmt.query(&[]).unwrap().get_expected_row().unwrap().get::(0)); + stmt.discard(); + } + assert_eq!(0, cache.len()); + } + + #[test] + fn test_ddl() { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch(r#" + CREATE TABLE foo (x INT); + INSERT INTO foo VALUES (1); + "#) + .unwrap(); + + let sql = "SELECT * FROM foo"; + + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!(1i32, + stmt.query_map(&[], |r| r.get(0)).unwrap().next().unwrap().unwrap()); + } + + db.execute_batch(r#" + ALTER TABLE foo ADD COLUMN y INT; + UPDATE foo SET y = 2; + "#) + .unwrap(); + + { + let mut stmt = db.prepare_cached(sql).unwrap(); + assert_eq!((1i32, 2i32), + stmt.query_map(&[], |r| (r.get(0), r.get(1))) + .unwrap() + .next() + .unwrap() + .unwrap()); + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 2e72d20..c3b565e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,6 +55,7 @@ extern crate libc; extern crate libsqlite3_sys as ffi; +extern crate lru_cache; #[macro_use] extern crate bitflags; #[cfg(test)] @@ -76,24 +77,32 @@ use libc::{c_int, c_char}; use types::{ToSql, FromSql}; use error::{error_from_sqlite_code, error_from_handle}; +use raw_statement::RawStatement; +use cache::StatementCache; pub use transaction::{SqliteTransaction, Transaction, TransactionBehavior}; pub use error::{SqliteError, Error}; +pub use cache::CachedStatement; #[cfg(feature = "load_extension")] pub use load_extension_guard::{SqliteLoadExtensionGuard, LoadExtensionGuard}; pub mod types; mod transaction; +mod cache; mod named_params; mod error; mod convenient; +mod raw_statement; #[cfg(feature = "load_extension")]mod load_extension_guard; #[cfg(feature = "trace")]pub mod trace; #[cfg(feature = "backup")]pub mod backup; #[cfg(feature = "functions")]pub mod functions; #[cfg(feature = "blob")]pub mod blob; +// Number of cached prepared statements we'll hold on to. +const STATEMENT_CACHE_DEFAULT_CAPACITY: usize = 16; + /// Old name for `Result`. `SqliteResult` is deprecated. pub type SqliteResult = Result; @@ -146,6 +155,7 @@ pub type SqliteConnection = Connection; /// A connection to a SQLite database. pub struct Connection { db: RefCell, + cache: StatementCache, path: Option, } @@ -190,6 +200,7 @@ impl Connection { InnerConnection::open_with_flags(&c_path, flags).map(|db| { Connection { db: RefCell::new(db), + cache: StatementCache::with_capacity(STATEMENT_CACHE_DEFAULT_CAPACITY), path: Some(path.as_ref().to_path_buf()), } }) @@ -208,50 +219,12 @@ impl Connection { InnerConnection::open_with_flags(&c_memory, flags).map(|db| { Connection { db: RefCell::new(db), + cache: StatementCache::with_capacity(STATEMENT_CACHE_DEFAULT_CAPACITY), path: None, } }) } - /// Begin a new transaction with the default behavior (DEFERRED). - /// - /// The transaction defaults to rolling back when it is dropped. If you want the transaction to - /// commit, you must call `commit` or `set_commit`. - /// - /// ## Example - /// - /// ```rust,no_run - /// # use rusqlite::{Connection, Result}; - /// # fn do_queries_part_1(conn: &Connection) -> Result<()> { Ok(()) } - /// # fn do_queries_part_2(conn: &Connection) -> Result<()> { Ok(()) } - /// fn perform_queries(conn: &Connection) -> Result<()> { - /// let tx = try!(conn.transaction()); - /// - /// try!(do_queries_part_1(conn)); // tx causes rollback if this fails - /// try!(do_queries_part_2(conn)); // tx causes rollback if this fails - /// - /// tx.commit() - /// } - /// ``` - /// - /// # Failure - /// - /// Will return `Err` if the underlying SQLite call fails. - pub fn transaction(&mut self) -> Result { - Transaction::new(self, TransactionBehavior::Deferred) - } - - /// Begin a new transaction with a specified behavior. - /// - /// See `transaction`. - /// - /// # Failure - /// - /// Will return `Err` if the underlying SQLite call fails. - pub fn transaction_with_behavior(&mut self, behavior: TransactionBehavior) -> Result { - Transaction::new(self, behavior) - } - /// Convenience method to run multiple SQL statements (that cannot take any parameters). /// /// Uses [sqlite3_exec](http://www.sqlite.org/c3ref/exec.html) under the hood. @@ -686,7 +659,7 @@ impl InnerConnection { &mut c_stmt, ptr::null_mut()) }; - self.decode_result(r).map(|_| Statement::new(conn, c_stmt)) + self.decode_result(r).map(|_| Statement::new(conn, RawStatement::new(c_stmt))) } fn changes(&mut self) -> c_int { @@ -707,25 +680,23 @@ pub type SqliteStatement<'conn> = Statement<'conn>; /// A prepared statement. pub struct Statement<'conn> { conn: &'conn Connection, - stmt: *mut ffi::sqlite3_stmt, - column_count: c_int, + stmt: RawStatement, } impl<'conn> Statement<'conn> { - fn new(conn: &Connection, stmt: *mut ffi::sqlite3_stmt) -> Statement { + fn new(conn: &Connection, stmt: RawStatement) -> Statement { Statement { conn: conn, stmt: stmt, - column_count: unsafe { ffi::sqlite3_column_count(stmt) }, } } /// Get all the column names in the result set of the prepared statement. pub fn column_names(&self) -> Vec<&str> { - let n = self.column_count; + let n = self.column_count(); let mut cols = Vec::with_capacity(n as usize); for i in 0..n { - let slice = unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.stmt, i)) }; + let slice = self.stmt.column_name(i); let s = str::from_utf8(slice.to_bytes()).unwrap(); cols.push(s); } @@ -734,7 +705,7 @@ impl<'conn> Statement<'conn> { /// Return the number of columns in the result set returned by the prepared statement. pub fn column_count(&self) -> i32 { - self.column_count + self.stmt.column_count() } /// Returns the column index in the result set for a given column name. @@ -744,10 +715,9 @@ impl<'conn> Statement<'conn> { /// Will return an `Error::InvalidColumnName` when there is no column with the specified `name`. pub fn column_index(&self, name: &str) -> Result { let bytes = name.as_bytes(); - let n = self.column_count; + let n = self.column_count(); for i in 0..n { - let slice = unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.stmt, i)) }; - if bytes == slice.to_bytes() { + if bytes == self.stmt.column_name(i).to_bytes() { return Ok(i); } } @@ -778,18 +748,16 @@ impl<'conn> Statement<'conn> { /// Will return `Err` if binding parameters fails, the executed statement returns rows (in /// which case `query` should be used instead), or the underling SQLite call fails. pub fn execute(&mut self, params: &[&ToSql]) -> Result { - unsafe { - try!(self.bind_parameters(params)); - self.execute_() - } + try!(self.bind_parameters(params)); + self.execute_() } - unsafe fn execute_(&mut self) -> Result { - let r = ffi::sqlite3_step(self.stmt); - ffi::sqlite3_reset(self.stmt); + fn execute_(&mut self) -> Result { + let r = self.stmt.step(); + self.stmt.reset(); match r { ffi::SQLITE_DONE => { - if self.column_count == 0 { + if self.column_count() == 0 { Ok(self.conn.changes()) } else { Err(Error::ExecuteReturnedResults) @@ -828,10 +796,7 @@ impl<'conn> Statement<'conn> { /// /// Will return `Err` if binding parameters fails. pub fn query<'a>(&'a mut self, params: &[&ToSql]) -> Result> { - unsafe { - try!(self.bind_parameters(params)); - } - + try!(self.bind_parameters(params)); Ok(Rows::new(self)) } @@ -903,32 +868,39 @@ impl<'conn> Statement<'conn> { self.finalize_() } - unsafe fn bind_parameters(&mut self, params: &[&ToSql]) -> Result<()> { - assert!(params.len() as c_int == ffi::sqlite3_bind_parameter_count(self.stmt), + fn bind_parameters(&mut self, params: &[&ToSql]) -> Result<()> { + assert!(params.len() as c_int == self.stmt.bind_parameter_count(), "incorrect number of parameters to query(): expected {}, got {}", - ffi::sqlite3_bind_parameter_count(self.stmt), + self.stmt.bind_parameter_count(), params.len()); for (i, p) in params.iter().enumerate() { - try!(self.conn.decode_result(p.bind_parameter(self.stmt, (i + 1) as c_int))); + try!(unsafe { + self.conn.decode_result(p.bind_parameter(self.stmt.ptr(), (i + 1) as c_int)) + }); } Ok(()) } fn finalize_(&mut self) -> Result<()> { - let r = unsafe { ffi::sqlite3_finalize(self.stmt) }; - self.stmt = ptr::null_mut(); - self.conn.decode_result(r) + let mut stmt = RawStatement::new(ptr::null_mut()); + mem::swap(&mut stmt, &mut self.stmt); + self.conn.decode_result(stmt.finalize()) + } +} + +impl<'conn> Into for Statement<'conn> { + fn into(mut self) -> RawStatement { + let mut stmt = RawStatement::new(ptr::null_mut()); + mem::swap(&mut stmt, &mut self.stmt); + stmt } } impl<'conn> fmt::Debug for Statement<'conn> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let sql = unsafe { - let c_slice = CStr::from_ptr(ffi::sqlite3_sql(self.stmt)).to_bytes(); - str::from_utf8(c_slice) - }; + let sql = str::from_utf8(self.stmt.sql().to_bytes()); f.debug_struct("Statement") .field("conn", self.conn) .field("stmt", &self.stmt) @@ -1007,9 +979,7 @@ impl<'stmt> Rows<'stmt> { fn reset(&mut self) { if let Some(stmt) = self.stmt.take() { - unsafe { - ffi::sqlite3_reset(stmt.stmt); - } + stmt.stmt.reset(); } } @@ -1025,7 +995,7 @@ impl<'stmt> Rows<'stmt> { /// or `query_and_then` instead, which return types that implement `Iterator`. pub fn next<'a>(&'a mut self) -> Option>> { self.stmt.and_then(|stmt| { - match unsafe { ffi::sqlite3_step(stmt.stmt) } { + match stmt.stmt.step() { ffi::SQLITE_ROW => { Some(Ok(Row { stmt: stmt, @@ -1088,8 +1058,8 @@ impl<'a, 'stmt> Row<'a, 'stmt> { unsafe { let idx = try!(idx.idx(self.stmt)); - if T::column_has_valid_sqlite_type(self.stmt.stmt, idx) { - FromSql::column_result(self.stmt.stmt, idx) + if T::column_has_valid_sqlite_type(self.stmt.stmt.ptr(), idx) { + FromSql::column_result(self.stmt.stmt.ptr(), idx) } else { Err(Error::InvalidColumnType) } @@ -1112,7 +1082,7 @@ pub trait RowIndex { impl RowIndex for i32 { #[inline] fn idx(&self, stmt: &Statement) -> Result { - if *self < 0 || *self >= stmt.column_count { + if *self < 0 || *self >= stmt.column_count() { Err(Error::InvalidColumnIndex(*self)) } else { Ok(*self) diff --git a/src/named_params.rs b/src/named_params.rs index 9bbb673..04ebcbd 100644 --- a/src/named_params.rs +++ b/src/named_params.rs @@ -2,8 +2,6 @@ use std::convert; use std::result; use libc::c_int; -use super::ffi; - use {Result, Error, Connection, Statement, MappedRows, AndThenRows, Rows, Row, str_to_cstring}; use types::ToSql; @@ -58,11 +56,7 @@ impl<'conn> Statement<'conn> { /// is valid but not a bound parameter of this statement. pub fn parameter_index(&self, name: &str) -> Result> { let c_name = try!(str_to_cstring(name)); - let c_index = unsafe { ffi::sqlite3_bind_parameter_index(self.stmt, c_name.as_ptr()) }; - Ok(match c_index { - 0 => None, // A zero is returned if no matching parameter is found. - n => Some(n), - }) + Ok(self.stmt.bind_parameter_index(&c_name)) } /// Execute the prepared statement with named parameter(s). If any parameters @@ -89,7 +83,7 @@ impl<'conn> Statement<'conn> { /// which case `query` should be used instead), or the underling SQLite call fails. pub fn execute_named(&mut self, params: &[(&str, &ToSql)]) -> Result { try!(self.bind_parameters_named(params)); - unsafe { self.execute_() } + self.execute_() } /// Execute the prepared statement with named parameter(s), returning a handle for the @@ -210,7 +204,7 @@ impl<'conn> Statement<'conn> { fn bind_parameters_named(&mut self, params: &[(&str, &ToSql)]) -> Result<()> { for &(name, value) in params { if let Some(i) = try!(self.parameter_index(name)) { - try!(self.conn.decode_result(unsafe { value.bind_parameter(self.stmt, i) })); + try!(self.conn.decode_result(unsafe { value.bind_parameter(self.stmt.ptr(), i) })); } else { return Err(Error::InvalidParameterName(name.into())); } diff --git a/src/raw_statement.rs b/src/raw_statement.rs new file mode 100644 index 0000000..9676d94 --- /dev/null +++ b/src/raw_statement.rs @@ -0,0 +1,70 @@ +use std::ffi::CStr; +use std::ptr; +use libc::c_int; +use super::ffi; + +// Private newtype for raw sqlite3_stmts that finalize themselves when dropped. +#[derive(Debug)] +pub struct RawStatement(*mut ffi::sqlite3_stmt); + +impl RawStatement { + pub fn new(stmt: *mut ffi::sqlite3_stmt) -> RawStatement { + RawStatement(stmt) + } + + pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt { + self.0 + } + + pub fn column_count(&self) -> c_int { + unsafe { ffi::sqlite3_column_count(self.0) } + } + + pub fn column_name(&self, idx: c_int) -> &CStr { + unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.0, idx)) } + } + + pub fn step(&self) -> c_int { + unsafe { ffi::sqlite3_step(self.0) } + } + + pub fn reset(&self) -> c_int { + unsafe { ffi::sqlite3_reset(self.0) } + } + + pub fn bind_parameter_count(&self) -> c_int { + unsafe { ffi::sqlite3_bind_parameter_count(self.0) } + } + + pub fn bind_parameter_index(&self, name: &CStr) -> Option { + let r = unsafe { ffi::sqlite3_bind_parameter_index(self.0, name.as_ptr()) }; + match r { + 0 => None, + i => Some(i), + } + } + + pub fn clear_bindings(&self) -> c_int { + unsafe { ffi::sqlite3_clear_bindings(self.0) } + } + + pub fn sql(&self) -> &CStr { + unsafe { CStr::from_ptr(ffi::sqlite3_sql(self.0)) } + } + + pub fn finalize(mut self) -> c_int { + self.finalize_() + } + + fn finalize_(&mut self) -> c_int { + let r = unsafe { ffi::sqlite3_finalize(self.0) }; + self.0 = ptr::null_mut(); + r + } +} + +impl Drop for RawStatement { + fn drop(&mut self) { + self.finalize_(); + } +} diff --git a/src/transaction.rs b/src/transaction.rs index a30fc30..97caa2e 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::ops::Deref; use {Result, Connection}; @@ -14,16 +13,30 @@ pub enum TransactionBehavior { Exclusive, } +/// Options for how a Transaction or Savepoint should behave when it is dropped. +#[derive(Copy,Clone,PartialEq,Eq)] +pub enum DropBehavior { + /// Roll back the changes. This is the default. + Rollback, + + /// Commit the changes. + Commit, + + /// Do not commit or roll back changes - this will leave the transaction or savepoint + /// open, so should be used with care. + Ignore, +} + /// Old name for `Transaction`. `SqliteTransaction` is deprecated. pub type SqliteTransaction<'conn> = Transaction<'conn>; -/// /// Represents a transaction on a database connection. /// /// ## Note /// -/// Transactions will roll back by default. Use the `set_commit` or `commit` methods to commit the -/// transaction. +/// Transactions will roll back by default. Use `commit` method to explicitly commit the +/// transaction, or use `set_drop_behavior` to change what happens when the transaction +/// is dropped. /// /// ## Example /// @@ -42,9 +55,39 @@ pub type SqliteTransaction<'conn> = Transaction<'conn>; /// ``` pub struct Transaction<'conn> { conn: &'conn Connection, + drop_behavior: DropBehavior, + committed: bool, +} + +/// Represents a savepoint on a database connection. +/// +/// ## Note +/// +/// Savepoints will roll back by default. Use `commit` method to explicitly commit the +/// savepoint, or use `set_drop_behavior` to change what happens when the savepoint +/// is dropped. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{Connection, Result}; +/// # fn do_queries_part_1(conn: &Connection) -> Result<()> { Ok(()) } +/// # fn do_queries_part_2(conn: &Connection) -> Result<()> { Ok(()) } +/// fn perform_queries(conn: &Connection) -> Result<()> { +/// let sp = try!(conn.savepoint()); +/// +/// try!(do_queries_part_1(conn)); // sp causes rollback if this fails +/// try!(do_queries_part_2(conn)); // sp causes rollback if this fails +/// +/// sp.commit() +/// } +/// ``` +pub struct Savepoint<'conn> { + conn: &'conn Connection, + name: String, depth: u32, - commit: bool, - finished: bool, + drop_behavior: DropBehavior, + committed: bool, } impl<'conn> Transaction<'conn> { @@ -58,9 +101,8 @@ impl<'conn> Transaction<'conn> { conn.execute_batch(query).map(move |_| { Transaction { conn: conn, - depth: 0, - commit: false, - finished: false, + drop_behavior: DropBehavior::Rollback, + committed: false, } }) } @@ -91,36 +133,23 @@ impl<'conn> Transaction<'conn> { /// tx.commit() /// } /// ``` - pub fn savepoint(&mut self) -> Result { - let new_depth = self.depth + 1; - self.conn.execute_batch(&format!("SAVEPOINT sp{}", new_depth)).map(|_| { - Transaction { - conn: self.conn, - depth: new_depth, - commit: false, - finished: false, - } - }) + pub fn savepoint(&mut self) -> Result { + Savepoint::with_depth(self.conn, 1) } - /// Returns whether or not the transaction is currently set to commit. - pub fn will_commit(&self) -> bool { - self.commit + /// Create a new savepoint with a custom savepoint name. See `savepoint()`. + pub fn savepoint_with_name>(&mut self, name: T) -> Result { + Savepoint::with_depth_and_name(self.conn, 1, name) } - /// Returns whether or not the transaction is currently set to rollback. - pub fn will_rollback(&self) -> bool { - !self.commit + /// Get the current setting for what happens to the transaction when it is dropped. + pub fn drop_behavior(&self) -> DropBehavior { + self.drop_behavior } - /// Set the transaction to commit at its completion. - pub fn set_commit(&mut self) { - self.commit = true - } - - /// Set the transaction to rollback at its completion. - pub fn set_rollback(&mut self) { - self.commit = false + /// Configure the transaction to perform the specified action when it is dropped. + pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) { + self.drop_behavior = drop_behavior } /// A convenience method which consumes and commits a transaction. @@ -129,13 +158,8 @@ impl<'conn> Transaction<'conn> { } fn commit_(&mut self) -> Result<()> { - self.finished = true; - let sql = if self.depth == 0 { - Cow::Borrowed("COMMIT") - } else { - Cow::Owned(format!("RELEASE sp{}", self.depth)) - }; - self.conn.execute_batch(&sql) + self.committed = true; + self.conn.execute_batch("COMMIT") } /// A convenience method which consumes and rolls back a transaction. @@ -144,17 +168,12 @@ impl<'conn> Transaction<'conn> { } fn rollback_(&mut self) -> Result<()> { - self.finished = true; - let sql = if self.depth == 0 { - Cow::Borrowed("ROLLBACK") - } else { - Cow::Owned(format!("ROLLBACK TO sp{}", self.depth)) - }; - self.conn.execute_batch(&sql) + self.committed = true; + self.conn.execute_batch("ROLLBACK") } /// Consumes the transaction, committing or rolling back according to the current setting - /// (see `will_commit`, `will_rollback`). + /// (see `drop_behavior`). /// /// Functionally equivalent to the `Drop` implementation, but allows callers to see any /// errors that occur. @@ -163,10 +182,13 @@ impl<'conn> Transaction<'conn> { } fn finish_(&mut self) -> Result<()> { - match (self.finished, self.commit) { - (true, _) => Ok(()), - (false, true) => self.commit_(), - (false, false) => self.rollback_(), + if self.committed { + return Ok(()); + } + match self.drop_behavior() { + DropBehavior::Commit => self.commit_(), + DropBehavior::Rollback => self.rollback_(), + DropBehavior::Ignore => Ok(()), } } } @@ -186,10 +208,196 @@ impl<'conn> Drop for Transaction<'conn> { } } +impl<'conn> Savepoint<'conn> { + fn with_depth_and_name>(conn: &Connection, depth: u32, name: T) -> Result { + let name = name.into(); + conn.execute_batch(&format!("SAVEPOINT {}", name)).map(|_| { + Savepoint { + conn: conn, + name: name, + depth: depth, + drop_behavior: DropBehavior::Rollback, + committed: false, + } + }) + } + + fn with_depth(conn: &Connection, depth: u32) -> Result { + let name = format!("_rusqlite_sp_{}", depth); + Savepoint::with_depth_and_name(conn, depth, name) + } + + /// Begin a new savepoint. Can be nested. + pub fn new(conn: &mut Connection) -> Result { + Savepoint::with_depth(conn, 0) + } + + /// Begin a new savepoint with a user-provided savepoint name. + pub fn with_name>(conn: &mut Connection, name: T) -> Result { + Savepoint::with_depth_and_name(conn, 0, name) + } + + /// Begin a nested savepoint. + pub fn savepoint(&mut self) -> Result { + Savepoint::with_depth(self.conn, self.depth + 1) + } + + /// Begin a nested savepoint with a user-provided savepoint name. + pub fn savepoint_with_name>(&mut self, name: T) -> Result { + Savepoint::with_depth_and_name(self.conn, self.depth + 1, name) + } + + /// Get the current setting for what happens to the savepoint when it is dropped. + pub fn drop_behavior(&self) -> DropBehavior { + self.drop_behavior + } + + /// Configure the savepoint to perform the specified action when it is dropped. + pub fn set_drop_behavior(&mut self, drop_behavior: DropBehavior) { + self.drop_behavior = drop_behavior + } + + /// A convenience method which consumes and commits a savepoint. + pub fn commit(mut self) -> Result<()> { + self.commit_() + } + + fn commit_(&mut self) -> Result<()> { + self.committed = true; + self.conn.execute_batch(&format!("RELEASE {}", self.name)) + } + + /// A convenience method which rolls back a savepoint. + /// + /// ## Note + /// + /// Unlike `Transaction`s, savepoints remain active after they have been rolled back, + /// and can be rolled back again or committed. + pub fn rollback(&mut self) -> Result<()> { + self.conn.execute_batch(&format!("ROLLBACK TO {}", self.name)) + } + + /// Consumes the savepoint, committing or rolling back according to the current setting + /// (see `drop_behavior`). + /// + /// Functionally equivalent to the `Drop` implementation, but allows callers to see any + /// errors that occur. + pub fn finish(mut self) -> Result<()> { + self.finish_() + } + + fn finish_(&mut self) -> Result<()> { + if self.committed { + return Ok(()); + } + match self.drop_behavior() { + DropBehavior::Commit => self.commit_(), + DropBehavior::Rollback => self.rollback(), + DropBehavior::Ignore => Ok(()), + } + } +} + +impl<'conn> Deref for Savepoint<'conn> { + type Target = Connection; + + fn deref(&self) -> &Connection { + self.conn + } +} + +#[allow(unused_must_use)] +impl<'conn> Drop for Savepoint<'conn> { + fn drop(&mut self) { + self.finish_(); + } +} + +impl Connection { + /// Begin a new transaction with the default behavior (DEFERRED). + /// + /// The transaction defaults to rolling back when it is dropped. If you want the transaction to + /// commit, you must call `commit` or `set_drop_behavior(DropBehavior::Commit)`. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # fn do_queries_part_1(conn: &Connection) -> Result<()> { Ok(()) } + /// # fn do_queries_part_2(conn: &Connection) -> Result<()> { Ok(()) } + /// fn perform_queries(conn: &Connection) -> Result<()> { + /// let tx = try!(conn.transaction()); + /// + /// try!(do_queries_part_1(conn)); // tx causes rollback if this fails + /// try!(do_queries_part_2(conn)); // tx causes rollback if this fails + /// + /// tx.commit() + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn transaction(&mut self) -> Result { + Transaction::new(self, TransactionBehavior::Deferred) + } + + /// Begin a new transaction with a specified behavior. + /// + /// See `transaction`. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn transaction_with_behavior(&mut self, behavior: TransactionBehavior) -> Result { + Transaction::new(self, behavior) + } + + /// Begin a new savepoint with the default behavior (DEFERRED). + /// + /// The savepoint defaults to rolling back when it is dropped. If you want the savepoint to + /// commit, you must call `commit` or `set_drop_behavior(DropBehavior::Commit)`. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{Connection, Result}; + /// # fn do_queries_part_1(conn: &Connection) -> Result<()> { Ok(()) } + /// # fn do_queries_part_2(conn: &Connection) -> Result<()> { Ok(()) } + /// fn perform_queries(conn: &Connection) -> Result<()> { + /// let sp = try!(conn.savepoint()); + /// + /// try!(do_queries_part_1(conn)); // sp causes rollback if this fails + /// try!(do_queries_part_2(conn)); // sp causes rollback if this fails + /// + /// sp.commit() + /// } + /// ``` + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn savepoint(&mut self) -> Result { + Savepoint::new(self) + } + + /// Begin a new savepoint with a specified name. + /// + /// See `savepoint`. + /// + /// # Failure + /// + /// Will return `Err` if the underlying SQLite call fails. + pub fn savepoint_with_name>(&mut self, name: T) -> Result { + Savepoint::with_name(self, name) + } +} + #[cfg(test)] #[cfg_attr(feature="clippy", allow(similar_names))] mod test { use Connection; + use super::DropBehavior; fn checked_memory_handle() -> Connection { let db = Connection::open_in_memory().unwrap(); @@ -208,7 +416,7 @@ mod test { { let mut tx = db.transaction().unwrap(); tx.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); - tx.set_commit() + tx.set_drop_behavior(DropBehavior::Commit) } { let tx = db.transaction().unwrap(); @@ -221,35 +429,36 @@ mod test { fn test_explicit_rollback_commit() { let mut db = checked_memory_handle(); { - let tx = db.transaction().unwrap(); - tx.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); - tx.rollback().unwrap(); - } - { - let tx = db.transaction().unwrap(); - tx.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); + let mut tx = db.transaction().unwrap(); + { + let mut sp = tx.savepoint().unwrap(); + sp.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); + sp.rollback().unwrap(); + sp.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); + sp.commit().unwrap(); + } tx.commit().unwrap(); } { let tx = db.transaction().unwrap(); - assert_eq!(2i32, + tx.execute_batch("INSERT INTO foo VALUES(4)").unwrap(); + tx.commit().unwrap(); + } + { + let tx = db.transaction().unwrap(); + assert_eq!(6i32, tx.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0)).unwrap()); } } #[test] fn test_savepoint() { - fn assert_current_sum(x: i32, conn: &Connection) { - let i = conn.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0)).unwrap(); - assert_eq!(x, i); - } - let mut db = checked_memory_handle(); { let mut tx = db.transaction().unwrap(); tx.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); assert_current_sum(1, &tx); - tx.set_commit(); + tx.set_drop_behavior(DropBehavior::Commit); { let mut sp1 = tx.savepoint().unwrap(); sp1.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); @@ -275,4 +484,64 @@ mod test { } assert_current_sum(1, &db); } + + #[test] + fn test_ignore_drop_behavior() { + let mut db = checked_memory_handle(); + + let mut tx = db.transaction().unwrap(); + { + let mut sp1 = tx.savepoint().unwrap(); + insert(1, &sp1); + sp1.rollback().unwrap(); + insert(2, &sp1); + { + let mut sp2 = sp1.savepoint().unwrap(); + sp2.set_drop_behavior(DropBehavior::Ignore); + insert(4, &sp2); + } + assert_current_sum(6, &sp1); + sp1.commit().unwrap(); + } + assert_current_sum(6, &tx); + } + + #[test] + fn test_savepoint_names() { + let mut db = checked_memory_handle(); + + { + let mut sp1 = db.savepoint_with_name("my_sp").unwrap(); + insert(1, &sp1); + assert_current_sum(1, &sp1); + { + let mut sp2 = sp1.savepoint_with_name("my_sp").unwrap(); + sp2.set_drop_behavior(DropBehavior::Commit); + insert(2, &sp2); + assert_current_sum(3, &sp2); + sp2.rollback().unwrap(); + assert_current_sum(1, &sp2); + insert(4, &sp2); + } + assert_current_sum(5, &sp1); + sp1.rollback().unwrap(); + { + let mut sp2 = sp1.savepoint_with_name("my_sp").unwrap(); + sp2.set_drop_behavior(DropBehavior::Ignore); + insert(8, &sp2); + } + assert_current_sum(8, &sp1); + sp1.commit().unwrap(); + } + assert_current_sum(8, &db); + } + + fn insert(x: i32, conn: &Connection) { + conn.execute("INSERT INTO foo VALUES(?)", &[&x]).unwrap(); + } + + fn assert_current_sum(x: i32, conn: &Connection) { + let i = conn.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0)).unwrap(); + assert_eq!(x, i); + } }