diff --git a/src/lib.rs b/src/lib.rs index 09e68bb..8708869 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -237,7 +237,7 @@ impl Connection { /// # Failure /// /// Will return `Err` if the underlying SQLite call fails. - pub fn transaction(&self) -> Result { + pub fn transaction(&mut self) -> Result { Transaction::new(self, TransactionBehavior::Deferred) } @@ -248,7 +248,7 @@ impl Connection { /// # Failure /// /// Will return `Err` if the underlying SQLite call fails. - pub fn transaction_with_behavior(&self, behavior: TransactionBehavior) -> Result { + pub fn transaction_with_behavior(&mut self, behavior: TransactionBehavior) -> Result { Transaction::new(self, behavior) } diff --git a/src/transaction.rs b/src/transaction.rs index e72add4..333f7a0 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -1,3 +1,4 @@ +use std::ops::Deref; use {Result, Connection}; /// Old name for `TransactionBehavior`. `SqliteTransactionBehavior` is deprecated. @@ -47,13 +48,13 @@ pub struct Transaction<'conn> { impl<'conn> Transaction<'conn> { /// Begin a new transaction. Cannot be nested; see `savepoint` for nested transactions. - pub fn new(conn: &Connection, behavior: TransactionBehavior) -> Result { + pub fn new(conn: &mut Connection, behavior: TransactionBehavior) -> Result { let query = match behavior { TransactionBehavior::Deferred => "BEGIN DEFERRED", TransactionBehavior::Immediate => "BEGIN IMMEDIATE", TransactionBehavior::Exclusive => "BEGIN EXCLUSIVE", }; - conn.execute_batch(query).map(|_| { + conn.execute_batch(query).map(move |_| { Transaction { conn: conn, depth: 0, @@ -89,7 +90,7 @@ impl<'conn> Transaction<'conn> { /// tx.commit() /// } /// ``` - pub fn savepoint(&self) -> Result { + pub fn savepoint(&mut self) -> Result { self.conn.execute_batch("SAVEPOINT sp").map(|_| { Transaction { conn: self.conn, @@ -166,6 +167,14 @@ impl<'conn> Transaction<'conn> { } } +impl<'conn> Deref for Transaction<'conn> { + type Target = Connection; + + fn deref(&self) -> &Connection { + self.conn + } +} + #[allow(unused_must_use)] impl<'conn> Drop for Transaction<'conn> { fn drop(&mut self) { @@ -186,62 +195,62 @@ mod test { #[test] fn test_drop() { - let db = checked_memory_handle(); + let mut db = checked_memory_handle(); { - let _tx = db.transaction().unwrap(); - db.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); + let tx = db.transaction().unwrap(); + tx.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); // default: rollback } { let mut tx = db.transaction().unwrap(); - db.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); + tx.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); tx.set_commit() } { - let _tx = db.transaction().unwrap(); + let tx = db.transaction().unwrap(); assert_eq!(2i32, - db.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0)).unwrap()); + tx.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0)).unwrap()); } } #[test] fn test_explicit_rollback_commit() { - let db = checked_memory_handle(); + let mut db = checked_memory_handle(); { let tx = db.transaction().unwrap(); - db.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); + tx.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); tx.rollback().unwrap(); } { let tx = db.transaction().unwrap(); - db.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); + tx.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); tx.commit().unwrap(); } { - let _tx = db.transaction().unwrap(); + let tx = db.transaction().unwrap(); assert_eq!(2i32, - db.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0)).unwrap()); + tx.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0)).unwrap()); } } #[test] fn test_savepoint() { - let db = checked_memory_handle(); + let mut db = checked_memory_handle(); { let mut tx = db.transaction().unwrap(); - db.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); + tx.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); tx.set_commit(); { let mut sp1 = tx.savepoint().unwrap(); - db.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); + sp1.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); sp1.set_commit(); { - let sp2 = sp1.savepoint().unwrap(); - db.execute_batch("INSERT INTO foo VALUES(4)").unwrap(); + let mut sp2 = sp1.savepoint().unwrap(); + sp2.execute_batch("INSERT INTO foo VALUES(4)").unwrap(); // will rollback sp2 { let sp3 = sp2.savepoint().unwrap(); - db.execute_batch("INSERT INTO foo VALUES(8)").unwrap(); + sp3.execute_batch("INSERT INTO foo VALUES(8)").unwrap(); sp3.commit().unwrap(); // committed sp3, but will be erased by sp2 rollback }