diff --git a/src/lib.rs b/src/lib.rs index 1eb00bb..62b9106 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,6 +76,7 @@ use libc::{c_int, c_char}; use types::{ToSql, FromSql}; use error::{error_from_sqlite_code, error_from_handle}; +use raw_statement::RawStatement; pub use transaction::{SqliteTransaction, Transaction, TransactionBehavior}; pub use error::{SqliteError, Error}; @@ -88,6 +89,7 @@ mod transaction; mod named_params; mod error; mod convenient; +mod raw_statement; #[cfg(feature = "load_extension")]mod load_extension_guard; #[cfg(feature = "trace")]pub mod trace; #[cfg(feature = "backup")]pub mod backup; @@ -687,7 +689,7 @@ impl InnerConnection { &mut c_stmt, ptr::null_mut()) }; - self.decode_result(r).map(|_| Statement::new(conn, c_stmt)) + self.decode_result(r).map(|_| Statement::new(conn, RawStatement::new(c_stmt))) } fn changes(&mut self) -> c_int { @@ -708,16 +710,17 @@ pub type SqliteStatement<'conn> = Statement<'conn>; /// A prepared statement. pub struct Statement<'conn> { conn: &'conn Connection, - stmt: *mut ffi::sqlite3_stmt, + stmt: RawStatement, column_count: c_int, } impl<'conn> Statement<'conn> { - fn new(conn: &Connection, stmt: *mut ffi::sqlite3_stmt) -> Statement { + fn new(conn: &Connection, stmt: RawStatement) -> Statement { + let column_count = stmt.column_count(); Statement { conn: conn, stmt: stmt, - column_count: unsafe { ffi::sqlite3_column_count(stmt) }, + column_count: column_count, } } @@ -726,7 +729,7 @@ impl<'conn> Statement<'conn> { let n = self.column_count; let mut cols = Vec::with_capacity(n as usize); for i in 0..n { - let slice = unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.stmt, i)) }; + let slice = self.stmt.column_name(i); let s = str::from_utf8(slice.to_bytes()).unwrap(); cols.push(s); } @@ -747,8 +750,7 @@ impl<'conn> Statement<'conn> { let bytes = name.as_bytes(); let n = self.column_count; for i in 0..n { - let slice = unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.stmt, i)) }; - if bytes == slice.to_bytes() { + if bytes == self.stmt.column_name(i).to_bytes() { return Ok(i); } } @@ -779,15 +781,13 @@ impl<'conn> Statement<'conn> { /// Will return `Err` if binding parameters fails, the executed statement returns rows (in /// which case `query` should be used instead), or the underling SQLite call fails. pub fn execute(&mut self, params: &[&ToSql]) -> Result { - unsafe { - try!(self.bind_parameters(params)); - self.execute_() - } + try!(self.bind_parameters(params)); + self.execute_() } - unsafe fn execute_(&mut self) -> Result { - let r = ffi::sqlite3_step(self.stmt); - ffi::sqlite3_reset(self.stmt); + fn execute_(&mut self) -> Result { + let r = self.stmt.step(); + self.stmt.reset(); match r { ffi::SQLITE_DONE => { if self.column_count == 0 { @@ -825,10 +825,7 @@ impl<'conn> Statement<'conn> { /// /// Will return `Err` if binding parameters fails. pub fn query<'a>(&'a mut self, params: &[&ToSql]) -> Result> { - unsafe { - try!(self.bind_parameters(params)); - } - + try!(self.bind_parameters(params)); Ok(Rows::new(self)) } @@ -889,14 +886,16 @@ impl<'conn> Statement<'conn> { self.finalize_() } - unsafe fn bind_parameters(&mut self, params: &[&ToSql]) -> Result<()> { - assert!(params.len() as c_int == ffi::sqlite3_bind_parameter_count(self.stmt), + fn bind_parameters(&mut self, params: &[&ToSql]) -> Result<()> { + assert!(params.len() as c_int == self.stmt.bind_parameter_count(), "incorrect number of parameters to query(): expected {}, got {}", - ffi::sqlite3_bind_parameter_count(self.stmt), + self.stmt.bind_parameter_count(), params.len()); for (i, p) in params.iter().enumerate() { - try!(self.conn.decode_result(p.bind_parameter(self.stmt, (i + 1) as c_int))); + try!(unsafe { + self.conn.decode_result(p.bind_parameter(self.stmt.ptr(), (i + 1) as c_int)) + }); } Ok(()) @@ -904,32 +903,25 @@ impl<'conn> Statement<'conn> { #[cfg(feature = "cache")] fn clear_bindings(&mut self) { - unsafe { - ffi::sqlite3_clear_bindings(self.stmt); - }; + self.stmt.clear_bindings(); } #[cfg(feature = "cache")] fn eq(&self, sql: &str) -> bool { - unsafe { - let c_slice = CStr::from_ptr(ffi::sqlite3_sql(self.stmt)).to_bytes(); - str::from_utf8(c_slice).unwrap().eq(sql) - } + let c_slice = self.stmt.sql().to_bytes(); + sql.as_bytes().eq(c_slice) } fn finalize_(&mut self) -> Result<()> { - let r = unsafe { ffi::sqlite3_finalize(self.stmt) }; - self.stmt = ptr::null_mut(); - self.conn.decode_result(r) + let mut stmt = RawStatement::new(ptr::null_mut()); + mem::swap(&mut stmt, &mut self.stmt); + self.conn.decode_result(stmt.finalize()) } } impl<'conn> fmt::Debug for Statement<'conn> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - let sql = unsafe { - let c_slice = CStr::from_ptr(ffi::sqlite3_sql(self.stmt)).to_bytes(); - str::from_utf8(c_slice) - }; + let sql = str::from_utf8(self.stmt.sql().to_bytes()); f.debug_struct("Statement") .field("conn", self.conn) .field("stmt", &self.stmt) @@ -1039,9 +1031,7 @@ impl<'stmt> Rows<'stmt> { fn reset(&mut self) { if let Some(stmt) = self.stmt.take() { - unsafe { - ffi::sqlite3_reset(stmt.stmt); - } + stmt.stmt.reset(); } } } @@ -1051,7 +1041,7 @@ impl<'stmt> Iterator for Rows<'stmt> { fn next(&mut self) -> Option>> { self.stmt.and_then(|stmt| { - match unsafe { ffi::sqlite3_step(stmt.stmt) } { + match stmt.stmt.step() { ffi::SQLITE_ROW => { let current_row = self.current_row.get() + 1; self.current_row.set(current_row); @@ -1146,8 +1136,8 @@ impl<'stmt> Row<'stmt> { unsafe { let idx = try!(idx.idx(self.stmt)); - if T::column_has_valid_sqlite_type(self.stmt.stmt, idx) { - FromSql::column_result(self.stmt.stmt, idx) + if T::column_has_valid_sqlite_type(self.stmt.stmt.ptr(), idx) { + FromSql::column_result(self.stmt.stmt.ptr(), idx) } else { Err(Error::InvalidColumnType) } diff --git a/src/named_params.rs b/src/named_params.rs index 3e62610..a881646 100644 --- a/src/named_params.rs +++ b/src/named_params.rs @@ -1,7 +1,5 @@ use libc::c_int; -use super::ffi; - use {Result, Error, Connection, Statement, Rows, Row, str_to_cstring}; use types::ToSql; @@ -56,11 +54,7 @@ impl<'conn> Statement<'conn> { /// is valid but not a bound parameter of this statement. pub fn parameter_index(&self, name: &str) -> Result> { let c_name = try!(str_to_cstring(name)); - let c_index = unsafe { ffi::sqlite3_bind_parameter_index(self.stmt, c_name.as_ptr()) }; - Ok(match c_index { - 0 => None, // A zero is returned if no matching parameter is found. - n => Some(n), - }) + Ok(self.stmt.bind_parameter_index(&c_name)) } /// Execute the prepared statement with named parameter(s). If any parameters @@ -87,7 +81,7 @@ impl<'conn> Statement<'conn> { /// which case `query` should be used instead), or the underling SQLite call fails. pub fn execute_named(&mut self, params: &[(&str, &ToSql)]) -> Result { try!(self.bind_parameters_named(params)); - unsafe { self.execute_() } + self.execute_() } /// Execute the prepared statement with named parameter(s), returning an iterator over the @@ -120,7 +114,7 @@ impl<'conn> Statement<'conn> { fn bind_parameters_named(&mut self, params: &[(&str, &ToSql)]) -> Result<()> { for &(name, value) in params { if let Some(i) = try!(self.parameter_index(name)) { - try!(self.conn.decode_result(unsafe { value.bind_parameter(self.stmt, i) })); + try!(self.conn.decode_result(unsafe { value.bind_parameter(self.stmt.ptr(), i) })); } else { return Err(Error::InvalidParameterName(name.into())); } diff --git a/src/raw_statement.rs b/src/raw_statement.rs new file mode 100644 index 0000000..896508e --- /dev/null +++ b/src/raw_statement.rs @@ -0,0 +1,71 @@ +use std::ffi::CStr; +use std::ptr; +use libc::c_int; +use super::ffi; + +// Private newtype for raw sqlite3_stmts that finalize themselves when dropped. +#[derive(Debug)] +pub struct RawStatement(*mut ffi::sqlite3_stmt); + +impl RawStatement { + pub fn new(stmt: *mut ffi::sqlite3_stmt) -> RawStatement { + RawStatement(stmt) + } + + pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt { + self.0 + } + + pub fn column_count(&self) -> c_int { + unsafe { ffi::sqlite3_column_count(self.0) } + } + + pub fn column_name(&self, idx: c_int) -> &CStr { + unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.0, idx)) } + } + + pub fn step(&self) -> c_int { + unsafe { ffi::sqlite3_step(self.0) } + } + + pub fn reset(&self) -> c_int { + unsafe { ffi::sqlite3_reset(self.0) } + } + + pub fn bind_parameter_count(&self) -> c_int { + unsafe { ffi::sqlite3_bind_parameter_count(self.0) } + } + + pub fn bind_parameter_index(&self, name: &CStr) -> Option { + let r = unsafe { ffi::sqlite3_bind_parameter_index(self.0, name.as_ptr()) }; + match r { + 0 => None, + i => Some(i), + } + } + + #[cfg(feature = "cache")] + pub fn clear_bindings(&self) -> c_int { + unsafe { ffi::sqlite3_clear_bindings(self.0) } + } + + pub fn sql(&self) -> &CStr { + unsafe { CStr::from_ptr(ffi::sqlite3_sql(self.0)) } + } + + pub fn finalize(mut self) -> c_int { + self.finalize_() + } + + fn finalize_(&mut self) -> c_int { + let r = unsafe { ffi::sqlite3_finalize(self.0) }; + self.0 = ptr::null_mut(); + r + } +} + +impl Drop for RawStatement { + fn drop(&mut self) { + self.finalize_(); + } +}