diff --git a/src/transaction.rs b/src/transaction.rs index a30fc30..79bb7d9 100644 --- a/src/transaction.rs +++ b/src/transaction.rs @@ -138,13 +138,8 @@ impl<'conn> Transaction<'conn> { self.conn.execute_batch(&sql) } - /// A convenience method which consumes and rolls back a transaction. - pub fn rollback(mut self) -> Result<()> { - self.rollback_() - } - - fn rollback_(&mut self) -> Result<()> { - self.finished = true; + /// A convenience method which rolls back a transaction. + pub fn rollback(&mut self) -> Result<()> { let sql = if self.depth == 0 { Cow::Borrowed("ROLLBACK") } else { @@ -166,7 +161,7 @@ impl<'conn> Transaction<'conn> { match (self.finished, self.commit) { (true, _) => Ok(()), (false, true) => self.commit_(), - (false, false) => self.rollback_(), + (false, false) => self.rollback(), } } } @@ -221,18 +216,24 @@ mod test { fn test_explicit_rollback_commit() { let mut db = checked_memory_handle(); { - let tx = db.transaction().unwrap(); - tx.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); - tx.rollback().unwrap(); - } - { - let tx = db.transaction().unwrap(); - tx.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); + let mut tx = db.transaction().unwrap(); + { + let mut sp = tx.savepoint().unwrap(); + sp.execute_batch("INSERT INTO foo VALUES(1)").unwrap(); + sp.rollback().unwrap(); + sp.execute_batch("INSERT INTO foo VALUES(2)").unwrap(); + sp.commit().unwrap(); + } tx.commit().unwrap(); } { let tx = db.transaction().unwrap(); - assert_eq!(2i32, + tx.execute_batch("INSERT INTO foo VALUES(4)").unwrap(); + tx.commit().unwrap(); + } + { + let tx = db.transaction().unwrap(); + assert_eq!(6i32, tx.query_row("SELECT SUM(x) FROM foo", &[], |r| r.get(0)).unwrap()); } }