Extract RawStatement wrapper around *mut sqlite3_stmt.

This commit is contained in:
John Gallagher 2016-05-17 10:06:43 -05:00
parent b76196ae1a
commit f6aba80f4b
3 changed files with 106 additions and 51 deletions

View File

@ -76,6 +76,7 @@ use libc::{c_int, c_char};
use types::{ToSql, FromSql}; use types::{ToSql, FromSql};
use error::{error_from_sqlite_code, error_from_handle}; use error::{error_from_sqlite_code, error_from_handle};
use raw_statement::RawStatement;
pub use transaction::{SqliteTransaction, Transaction, TransactionBehavior}; pub use transaction::{SqliteTransaction, Transaction, TransactionBehavior};
pub use error::{SqliteError, Error}; pub use error::{SqliteError, Error};
@ -88,6 +89,7 @@ mod transaction;
mod named_params; mod named_params;
mod error; mod error;
mod convenient; mod convenient;
mod raw_statement;
#[cfg(feature = "load_extension")]mod load_extension_guard; #[cfg(feature = "load_extension")]mod load_extension_guard;
#[cfg(feature = "trace")]pub mod trace; #[cfg(feature = "trace")]pub mod trace;
#[cfg(feature = "backup")]pub mod backup; #[cfg(feature = "backup")]pub mod backup;
@ -687,7 +689,7 @@ impl InnerConnection {
&mut c_stmt, &mut c_stmt,
ptr::null_mut()) ptr::null_mut())
}; };
self.decode_result(r).map(|_| Statement::new(conn, c_stmt)) self.decode_result(r).map(|_| Statement::new(conn, RawStatement::new(c_stmt)))
} }
fn changes(&mut self) -> c_int { fn changes(&mut self) -> c_int {
@ -708,16 +710,17 @@ pub type SqliteStatement<'conn> = Statement<'conn>;
/// A prepared statement. /// A prepared statement.
pub struct Statement<'conn> { pub struct Statement<'conn> {
conn: &'conn Connection, conn: &'conn Connection,
stmt: *mut ffi::sqlite3_stmt, stmt: RawStatement,
column_count: c_int, column_count: c_int,
} }
impl<'conn> Statement<'conn> { impl<'conn> Statement<'conn> {
fn new(conn: &Connection, stmt: *mut ffi::sqlite3_stmt) -> Statement { fn new(conn: &Connection, stmt: RawStatement) -> Statement {
let column_count = stmt.column_count();
Statement { Statement {
conn: conn, conn: conn,
stmt: stmt, stmt: stmt,
column_count: unsafe { ffi::sqlite3_column_count(stmt) }, column_count: column_count,
} }
} }
@ -726,7 +729,7 @@ impl<'conn> Statement<'conn> {
let n = self.column_count; let n = self.column_count;
let mut cols = Vec::with_capacity(n as usize); let mut cols = Vec::with_capacity(n as usize);
for i in 0..n { for i in 0..n {
let slice = unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.stmt, i)) }; let slice = self.stmt.column_name(i);
let s = str::from_utf8(slice.to_bytes()).unwrap(); let s = str::from_utf8(slice.to_bytes()).unwrap();
cols.push(s); cols.push(s);
} }
@ -747,8 +750,7 @@ impl<'conn> Statement<'conn> {
let bytes = name.as_bytes(); let bytes = name.as_bytes();
let n = self.column_count; let n = self.column_count;
for i in 0..n { for i in 0..n {
let slice = unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.stmt, i)) }; if bytes == self.stmt.column_name(i).to_bytes() {
if bytes == slice.to_bytes() {
return Ok(i); return Ok(i);
} }
} }
@ -779,15 +781,13 @@ impl<'conn> Statement<'conn> {
/// Will return `Err` if binding parameters fails, the executed statement returns rows (in /// Will return `Err` if binding parameters fails, the executed statement returns rows (in
/// which case `query` should be used instead), or the underling SQLite call fails. /// which case `query` should be used instead), or the underling SQLite call fails.
pub fn execute(&mut self, params: &[&ToSql]) -> Result<c_int> { pub fn execute(&mut self, params: &[&ToSql]) -> Result<c_int> {
unsafe { try!(self.bind_parameters(params));
try!(self.bind_parameters(params)); self.execute_()
self.execute_()
}
} }
unsafe fn execute_(&mut self) -> Result<c_int> { fn execute_(&mut self) -> Result<c_int> {
let r = ffi::sqlite3_step(self.stmt); let r = self.stmt.step();
ffi::sqlite3_reset(self.stmt); self.stmt.reset();
match r { match r {
ffi::SQLITE_DONE => { ffi::SQLITE_DONE => {
if self.column_count == 0 { if self.column_count == 0 {
@ -825,10 +825,7 @@ impl<'conn> Statement<'conn> {
/// ///
/// Will return `Err` if binding parameters fails. /// Will return `Err` if binding parameters fails.
pub fn query<'a>(&'a mut self, params: &[&ToSql]) -> Result<Rows<'a>> { pub fn query<'a>(&'a mut self, params: &[&ToSql]) -> Result<Rows<'a>> {
unsafe { try!(self.bind_parameters(params));
try!(self.bind_parameters(params));
}
Ok(Rows::new(self)) Ok(Rows::new(self))
} }
@ -889,14 +886,16 @@ impl<'conn> Statement<'conn> {
self.finalize_() self.finalize_()
} }
unsafe fn bind_parameters(&mut self, params: &[&ToSql]) -> Result<()> { fn bind_parameters(&mut self, params: &[&ToSql]) -> Result<()> {
assert!(params.len() as c_int == ffi::sqlite3_bind_parameter_count(self.stmt), assert!(params.len() as c_int == self.stmt.bind_parameter_count(),
"incorrect number of parameters to query(): expected {}, got {}", "incorrect number of parameters to query(): expected {}, got {}",
ffi::sqlite3_bind_parameter_count(self.stmt), self.stmt.bind_parameter_count(),
params.len()); params.len());
for (i, p) in params.iter().enumerate() { for (i, p) in params.iter().enumerate() {
try!(self.conn.decode_result(p.bind_parameter(self.stmt, (i + 1) as c_int))); try!(unsafe {
self.conn.decode_result(p.bind_parameter(self.stmt.ptr(), (i + 1) as c_int))
});
} }
Ok(()) Ok(())
@ -904,32 +903,25 @@ impl<'conn> Statement<'conn> {
#[cfg(feature = "cache")] #[cfg(feature = "cache")]
fn clear_bindings(&mut self) { fn clear_bindings(&mut self) {
unsafe { self.stmt.clear_bindings();
ffi::sqlite3_clear_bindings(self.stmt);
};
} }
#[cfg(feature = "cache")] #[cfg(feature = "cache")]
fn eq(&self, sql: &str) -> bool { fn eq(&self, sql: &str) -> bool {
unsafe { let c_slice = self.stmt.sql().to_bytes();
let c_slice = CStr::from_ptr(ffi::sqlite3_sql(self.stmt)).to_bytes(); sql.as_bytes().eq(c_slice)
str::from_utf8(c_slice).unwrap().eq(sql)
}
} }
fn finalize_(&mut self) -> Result<()> { fn finalize_(&mut self) -> Result<()> {
let r = unsafe { ffi::sqlite3_finalize(self.stmt) }; let mut stmt = RawStatement::new(ptr::null_mut());
self.stmt = ptr::null_mut(); mem::swap(&mut stmt, &mut self.stmt);
self.conn.decode_result(r) self.conn.decode_result(stmt.finalize())
} }
} }
impl<'conn> fmt::Debug for Statement<'conn> { impl<'conn> fmt::Debug for Statement<'conn> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let sql = unsafe { let sql = str::from_utf8(self.stmt.sql().to_bytes());
let c_slice = CStr::from_ptr(ffi::sqlite3_sql(self.stmt)).to_bytes();
str::from_utf8(c_slice)
};
f.debug_struct("Statement") f.debug_struct("Statement")
.field("conn", self.conn) .field("conn", self.conn)
.field("stmt", &self.stmt) .field("stmt", &self.stmt)
@ -1039,9 +1031,7 @@ impl<'stmt> Rows<'stmt> {
fn reset(&mut self) { fn reset(&mut self) {
if let Some(stmt) = self.stmt.take() { if let Some(stmt) = self.stmt.take() {
unsafe { stmt.stmt.reset();
ffi::sqlite3_reset(stmt.stmt);
}
} }
} }
} }
@ -1051,7 +1041,7 @@ impl<'stmt> Iterator for Rows<'stmt> {
fn next(&mut self) -> Option<Result<Row<'stmt>>> { fn next(&mut self) -> Option<Result<Row<'stmt>>> {
self.stmt.and_then(|stmt| { self.stmt.and_then(|stmt| {
match unsafe { ffi::sqlite3_step(stmt.stmt) } { match stmt.stmt.step() {
ffi::SQLITE_ROW => { ffi::SQLITE_ROW => {
let current_row = self.current_row.get() + 1; let current_row = self.current_row.get() + 1;
self.current_row.set(current_row); self.current_row.set(current_row);
@ -1146,8 +1136,8 @@ impl<'stmt> Row<'stmt> {
unsafe { unsafe {
let idx = try!(idx.idx(self.stmt)); let idx = try!(idx.idx(self.stmt));
if T::column_has_valid_sqlite_type(self.stmt.stmt, idx) { if T::column_has_valid_sqlite_type(self.stmt.stmt.ptr(), idx) {
FromSql::column_result(self.stmt.stmt, idx) FromSql::column_result(self.stmt.stmt.ptr(), idx)
} else { } else {
Err(Error::InvalidColumnType) Err(Error::InvalidColumnType)
} }

View File

@ -1,7 +1,5 @@
use libc::c_int; use libc::c_int;
use super::ffi;
use {Result, Error, Connection, Statement, Rows, Row, str_to_cstring}; use {Result, Error, Connection, Statement, Rows, Row, str_to_cstring};
use types::ToSql; use types::ToSql;
@ -56,11 +54,7 @@ impl<'conn> Statement<'conn> {
/// is valid but not a bound parameter of this statement. /// is valid but not a bound parameter of this statement.
pub fn parameter_index(&self, name: &str) -> Result<Option<i32>> { pub fn parameter_index(&self, name: &str) -> Result<Option<i32>> {
let c_name = try!(str_to_cstring(name)); let c_name = try!(str_to_cstring(name));
let c_index = unsafe { ffi::sqlite3_bind_parameter_index(self.stmt, c_name.as_ptr()) }; Ok(self.stmt.bind_parameter_index(&c_name))
Ok(match c_index {
0 => None, // A zero is returned if no matching parameter is found.
n => Some(n),
})
} }
/// Execute the prepared statement with named parameter(s). If any parameters /// Execute the prepared statement with named parameter(s). If any parameters
@ -87,7 +81,7 @@ impl<'conn> Statement<'conn> {
/// which case `query` should be used instead), or the underling SQLite call fails. /// which case `query` should be used instead), or the underling SQLite call fails.
pub fn execute_named(&mut self, params: &[(&str, &ToSql)]) -> Result<c_int> { pub fn execute_named(&mut self, params: &[(&str, &ToSql)]) -> Result<c_int> {
try!(self.bind_parameters_named(params)); try!(self.bind_parameters_named(params));
unsafe { self.execute_() } self.execute_()
} }
/// Execute the prepared statement with named parameter(s), returning an iterator over the /// Execute the prepared statement with named parameter(s), returning an iterator over the
@ -120,7 +114,7 @@ impl<'conn> Statement<'conn> {
fn bind_parameters_named(&mut self, params: &[(&str, &ToSql)]) -> Result<()> { fn bind_parameters_named(&mut self, params: &[(&str, &ToSql)]) -> Result<()> {
for &(name, value) in params { for &(name, value) in params {
if let Some(i) = try!(self.parameter_index(name)) { if let Some(i) = try!(self.parameter_index(name)) {
try!(self.conn.decode_result(unsafe { value.bind_parameter(self.stmt, i) })); try!(self.conn.decode_result(unsafe { value.bind_parameter(self.stmt.ptr(), i) }));
} else { } else {
return Err(Error::InvalidParameterName(name.into())); return Err(Error::InvalidParameterName(name.into()));
} }

71
src/raw_statement.rs Normal file
View File

@ -0,0 +1,71 @@
use std::ffi::CStr;
use std::ptr;
use libc::c_int;
use super::ffi;
// Private newtype for raw sqlite3_stmts that finalize themselves when dropped.
#[derive(Debug)]
pub struct RawStatement(*mut ffi::sqlite3_stmt);
impl RawStatement {
pub fn new(stmt: *mut ffi::sqlite3_stmt) -> RawStatement {
RawStatement(stmt)
}
pub unsafe fn ptr(&self) -> *mut ffi::sqlite3_stmt {
self.0
}
pub fn column_count(&self) -> c_int {
unsafe { ffi::sqlite3_column_count(self.0) }
}
pub fn column_name(&self, idx: c_int) -> &CStr {
unsafe { CStr::from_ptr(ffi::sqlite3_column_name(self.0, idx)) }
}
pub fn step(&self) -> c_int {
unsafe { ffi::sqlite3_step(self.0) }
}
pub fn reset(&self) -> c_int {
unsafe { ffi::sqlite3_reset(self.0) }
}
pub fn bind_parameter_count(&self) -> c_int {
unsafe { ffi::sqlite3_bind_parameter_count(self.0) }
}
pub fn bind_parameter_index(&self, name: &CStr) -> Option<c_int> {
let r = unsafe { ffi::sqlite3_bind_parameter_index(self.0, name.as_ptr()) };
match r {
0 => None,
i => Some(i),
}
}
#[cfg(feature = "cache")]
pub fn clear_bindings(&self) -> c_int {
unsafe { ffi::sqlite3_clear_bindings(self.0) }
}
pub fn sql(&self) -> &CStr {
unsafe { CStr::from_ptr(ffi::sqlite3_sql(self.0)) }
}
pub fn finalize(mut self) -> c_int {
self.finalize_()
}
fn finalize_(&mut self) -> c_int {
let r = unsafe { ffi::sqlite3_finalize(self.0) };
self.0 = ptr::null_mut();
r
}
}
impl Drop for RawStatement {
fn drop(&mut self) {
self.finalize_();
}
}