Rework ToSql to be implementable without unsafe.

This commit is contained in:
John Gallagher 2016-05-25 22:57:43 -04:00
parent 9e49452300
commit e4926ac0d7
10 changed files with 206 additions and 122 deletions

View File

@ -52,10 +52,9 @@ use std::io;
use std::cmp::min; use std::cmp::min;
use std::mem; use std::mem;
use std::ptr; use std::ptr;
use libc::c_int;
use super::ffi; use super::ffi;
use super::types::ToSql; use super::types::{ToSql, ToSqlOutput};
use {Result, Connection, DatabaseName}; use {Result, Connection, DatabaseName};
/// Handle to an open BLOB. /// Handle to an open BLOB.
@ -244,9 +243,9 @@ impl<'conn> Drop for Blob<'conn> {
pub struct ZeroBlob(pub i32); pub struct ZeroBlob(pub i32);
impl ToSql for ZeroBlob { impl ToSql for ZeroBlob {
unsafe fn bind_parameter(&self, stmt: *mut ffi::sqlite3_stmt, col: c_int) -> c_int { fn to_sql(&self) -> Result<ToSqlOutput> {
let ZeroBlob(length) = *self; let ZeroBlob(length) = *self;
ffi::sqlite3_bind_zeroblob(stmt, col, length) Ok(ToSqlOutput::ZeroBlob(length))
} }
} }

View File

@ -182,7 +182,7 @@ impl<'a> ValueRef<'a> {
ValueRef::Blob(from_raw_parts(blob as *const u8, len as usize)) ValueRef::Blob(from_raw_parts(blob as *const u8, len as usize))
} }
_ => unreachable!("sqlite3_value_type returned invalid value") _ => unreachable!("sqlite3_value_type returned invalid value"),
} }
} }
} }
@ -219,7 +219,7 @@ impl<'a> Context<'a> {
let value = unsafe { ValueRef::from_value(arg) }; let value = unsafe { ValueRef::from_value(arg) };
FromSql::column_result(value).map_err(|err| match err { FromSql::column_result(value).map_err(|err| match err {
Error::InvalidColumnType => Error::InvalidFunctionParameterType, Error::InvalidColumnType => Error::InvalidFunctionParameterType,
_ => err _ => err,
}) })
} }

View File

@ -73,9 +73,9 @@ use std::cell::RefCell;
use std::ffi::{CStr, CString}; use std::ffi::{CStr, CString};
use std::result; use std::result;
use std::str; use std::str;
use libc::{c_int, c_char}; use libc::{c_int, c_char, c_void};
use types::{ToSql, FromSql, ValueRef}; use types::{ToSql, ToSqlOutput, FromSql, ValueRef};
use error::{error_from_sqlite_code, error_from_handle}; use error::{error_from_sqlite_code, error_from_handle};
use raw_statement::RawStatement; use raw_statement::RawStatement;
use cache::StatementCache; use cache::StatementCache;
@ -874,6 +874,54 @@ impl<'conn> Statement<'conn> {
self.finalize_() self.finalize_()
} }
fn bind_parameter(&self, param: &ToSql, col: c_int) -> Result<()> {
// This should be
// let value = try!(param.to_sql());
// but that hits a bug in the Rust compiler around re-exported
// trait visibility. It's fixed in 1.9.
let value = try!(ToSql::to_sql(param));
let ptr = unsafe { self.stmt.ptr() };
let value = match value {
ToSqlOutput::Borrowed(v) => v,
ToSqlOutput::Owned(ref v) => ValueRef::from(v),
#[cfg(feature = "blob")]
ToSqlOutput::ZeroBlob(len) => {
return self.conn.decode_result(unsafe { ffi::sqlite3_bind_zeroblob(ptr, col, len) });
}
};
self.conn.decode_result(match value {
ValueRef::Null => unsafe { ffi::sqlite3_bind_null(ptr, col) },
ValueRef::Integer(i) => unsafe { ffi::sqlite3_bind_int64(ptr, col, i) },
ValueRef::Real(r) => unsafe { ffi::sqlite3_bind_double(ptr, col, r) },
ValueRef::Text(ref s) => unsafe {
let length = s.len();
if length > ::std::i32::MAX as usize {
ffi::SQLITE_TOOBIG
} else {
let c_str = try!(str_to_cstring(s));
let destructor = if length > 0 {
ffi::SQLITE_TRANSIENT()
} else {
ffi::SQLITE_STATIC()
};
ffi::sqlite3_bind_text(ptr, col, c_str.as_ptr(), length as c_int, destructor)
}
},
ValueRef::Blob(ref b) => unsafe {
let length = b.len();
if length > ::std::i32::MAX as usize {
ffi::SQLITE_TOOBIG
} else if length == 0 {
ffi::sqlite3_bind_zeroblob(ptr, col, 0)
} else {
ffi::sqlite3_bind_blob(ptr, col, b.as_ptr() as *const c_void, length as c_int, ffi::SQLITE_TRANSIENT())
}
},
})
}
fn bind_parameters(&mut self, params: &[&ToSql]) -> Result<()> { fn bind_parameters(&mut self, params: &[&ToSql]) -> Result<()> {
assert!(params.len() as c_int == self.stmt.bind_parameter_count(), assert!(params.len() as c_int == self.stmt.bind_parameter_count(),
"incorrect number of parameters to query(): expected {}, got {}", "incorrect number of parameters to query(): expected {}, got {}",
@ -881,14 +929,7 @@ impl<'conn> Statement<'conn> {
params.len()); params.len());
for (i, p) in params.iter().enumerate() { for (i, p) in params.iter().enumerate() {
try!(unsafe { try!(self.bind_parameter(*p, (i + 1) as c_int));
self.conn.decode_result(
// This should be
// `p.bind_parameter(self.stmt.ptr(), (i + 1) as c_int)`
// but that doesn't compile until Rust 1.9 due to a compiler bug.
ToSql::bind_parameter(*p, self.stmt.ptr(), (i + 1) as c_int)
)
});
} }
Ok(()) Ok(())
@ -1128,7 +1169,7 @@ impl<'a> ValueRef<'a> {
ValueRef::Blob(from_raw_parts(blob as *const u8, len as usize)) ValueRef::Blob(from_raw_parts(blob as *const u8, len as usize))
} }
_ => unreachable!("sqlite3_column_type returned invalid value") _ => unreachable!("sqlite3_column_type returned invalid value"),
} }
} }
} }

View File

@ -204,12 +204,7 @@ impl<'conn> Statement<'conn> {
fn bind_parameters_named(&mut self, params: &[(&str, &ToSql)]) -> Result<()> { fn bind_parameters_named(&mut self, params: &[(&str, &ToSql)]) -> Result<()> {
for &(name, value) in params { for &(name, value) in params {
if let Some(i) = try!(self.parameter_index(name)) { if let Some(i) = try!(self.parameter_index(name)) {
try!(self.conn.decode_result(unsafe { try!(self.bind_parameter(value, i));
// This should be
// `value.bind_parameter(self.stmt.ptr(), i)`
// but that doesn't compile until Rust 1.9 due to a compiler bug.
ToSql::bind_parameter(value, self.stmt.ptr(), i)
}));
} else { } else {
return Err(Error::InvalidParameterName(name.into())); return Err(Error::InvalidParameterName(name.into()));
} }

View File

@ -4,18 +4,15 @@ extern crate chrono;
use std::borrow::Cow; use std::borrow::Cow;
use self::chrono::{NaiveDate, NaiveTime, NaiveDateTime, DateTime, TimeZone, UTC, Local}; use self::chrono::{NaiveDate, NaiveTime, NaiveDateTime, DateTime, TimeZone, UTC, Local};
use libc::c_int;
use {Error, Result}; use {Error, Result};
use types::{FromSql, ToSql, ValueRef}; use types::{FromSql, ToSql, ToSqlOutput, ValueRef};
use ffi::sqlite3_stmt;
/// ISO 8601 calendar date without timezone => "YYYY-MM-DD" /// ISO 8601 calendar date without timezone => "YYYY-MM-DD"
impl ToSql for NaiveDate { impl ToSql for NaiveDate {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { fn to_sql(&self) -> Result<ToSqlOutput> {
let date_str = self.format("%Y-%m-%d").to_string(); let date_str = self.format("%Y-%m-%d").to_string();
date_str.bind_parameter(stmt, col) Ok(ToSqlOutput::from(date_str))
} }
} }
@ -31,9 +28,9 @@ impl FromSql for NaiveDate {
/// ISO 8601 time without timezone => "HH:MM:SS.SSS" /// ISO 8601 time without timezone => "HH:MM:SS.SSS"
impl ToSql for NaiveTime { impl ToSql for NaiveTime {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { fn to_sql(&self) -> Result<ToSqlOutput> {
let date_str = self.format("%H:%M:%S%.f").to_string(); let date_str = self.format("%H:%M:%S%.f").to_string();
date_str.bind_parameter(stmt, col) Ok(ToSqlOutput::from(date_str))
} }
} }
@ -56,9 +53,9 @@ impl FromSql for NaiveTime {
/// ISO 8601 combined date and time without timezone => "YYYY-MM-DD HH:MM:SS.SSS" /// ISO 8601 combined date and time without timezone => "YYYY-MM-DD HH:MM:SS.SSS"
impl ToSql for NaiveDateTime { impl ToSql for NaiveDateTime {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { fn to_sql(&self) -> Result<ToSqlOutput> {
let date_str = self.format("%Y-%m-%dT%H:%M:%S%.f").to_string(); let date_str = self.format("%Y-%m-%dT%H:%M:%S%.f").to_string();
date_str.bind_parameter(stmt, col) Ok(ToSqlOutput::from(date_str))
} }
} }
@ -83,9 +80,8 @@ impl FromSql for NaiveDateTime {
/// Date and time with time zone => UTC RFC3339 timestamp ("YYYY-MM-DDTHH:MM:SS.SSS+00:00"). /// Date and time with time zone => UTC RFC3339 timestamp ("YYYY-MM-DDTHH:MM:SS.SSS+00:00").
impl<Tz: TimeZone> ToSql for DateTime<Tz> { impl<Tz: TimeZone> ToSql for DateTime<Tz> {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { fn to_sql(&self) -> Result<ToSqlOutput> {
let utc_dt = self.with_timezone(&UTC); Ok(ToSqlOutput::from(self.with_timezone(&UTC).to_rfc3339()))
utc_dt.to_rfc3339().bind_parameter(stmt, col)
} }
} }

View File

@ -50,10 +50,8 @@
//! `FromSql` for the cases where you want to know if a value was NULL (which gets translated to //! `FromSql` for the cases where you want to know if a value was NULL (which gets translated to
//! `None`). //! `None`).
pub use ffi::sqlite3_stmt;
pub use self::from_sql::FromSql; pub use self::from_sql::FromSql;
pub use self::to_sql::ToSql; pub use self::to_sql::{ToSql, ToSqlOutput};
pub use self::value_ref::ValueRef; pub use self::value_ref::ValueRef;
mod value_ref; mod value_ref;
@ -102,6 +100,42 @@ pub enum Value {
Blob(Vec<u8>), Blob(Vec<u8>),
} }
impl From<Null> for Value {
fn from(_: Null) -> Value {
Value::Null
}
}
impl From<i32> for Value {
fn from(i: i32) -> Value {
Value::Integer(i as i64)
}
}
impl From<i64> for Value {
fn from(i: i64) -> Value {
Value::Integer(i)
}
}
impl From<f64> for Value {
fn from(f: f64) -> Value {
Value::Real(f)
}
}
impl From<String> for Value {
fn from(s: String) -> Value {
Value::Text(s)
}
}
impl From<Vec<u8>> for Value {
fn from(v: Vec<u8>) -> Value {
Value::Blob(v)
}
}
#[cfg(test)] #[cfg(test)]
#[cfg_attr(feature="clippy", allow(similar_names))] #[cfg_attr(feature="clippy", allow(similar_names))]
mod test { mod test {
@ -111,6 +145,7 @@ mod test {
use Error; use Error;
use libc::{c_int, c_double}; use libc::{c_int, c_double};
use std::f64::EPSILON; use std::f64::EPSILON;
use super::Value;
fn checked_memory_handle() -> Connection { fn checked_memory_handle() -> Connection {
let db = Connection::open_in_memory().unwrap(); let db = Connection::open_in_memory().unwrap();
@ -133,6 +168,17 @@ mod test {
fn test_str() { fn test_str() {
let db = checked_memory_handle(); let db = checked_memory_handle();
let s = "hello, world!";
db.execute("INSERT INTO foo(t) VALUES (?)", &[&s]).unwrap();
let from: String = db.query_row("SELECT t FROM foo", &[], |r| r.get(0)).unwrap();
assert_eq!(from, s);
}
#[test]
fn test_string() {
let db = checked_memory_handle();
let s = "hello, world!"; let s = "hello, world!";
db.execute("INSERT INTO foo(t) VALUES (?)", &[&s.to_owned()]).unwrap(); db.execute("INSERT INTO foo(t) VALUES (?)", &[&s.to_owned()]).unwrap();
@ -140,6 +186,15 @@ mod test {
assert_eq!(from, s); assert_eq!(from, s);
} }
#[test]
fn test_value() {
let db = checked_memory_handle();
db.execute("INSERT INTO foo(i) VALUES (?)", &[&Value::Integer(10)]).unwrap();
assert_eq!(10i64, db.query_row("SELECT i FROM foo", &[], |r| r.get(0)).unwrap());
}
#[test] #[test]
fn test_option() { fn test_option() {
let db = checked_memory_handle(); let db = checked_memory_handle();

View File

@ -1,19 +1,15 @@
//! `ToSql` and `FromSql` implementation for JSON `Value`. //! `ToSql` and `FromSql` implementation for JSON `Value`.
extern crate serde_json; extern crate serde_json;
use libc::c_int;
use self::serde_json::Value; use self::serde_json::Value;
use {Error, Result}; use {Error, Result};
use types::{FromSql, ToSql, ValueRef}; use types::{FromSql, ToSql, ToSqlOutput, ValueRef};
use ffi::sqlite3_stmt;
/// Serialize JSON `Value` to text. /// Serialize JSON `Value` to text.
impl ToSql for Value { impl ToSql for Value {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { fn to_sql(&self) -> Result<ToSqlOutput> {
let s = serde_json::to_string(self).unwrap(); Ok(ToSqlOutput::from(serde_json::to_string(self).unwrap()))
s.bind_parameter(stmt, col)
} }
} }
@ -24,7 +20,8 @@ impl FromSql for Value {
ValueRef::Text(ref s) => serde_json::from_str(s), ValueRef::Text(ref s) => serde_json::from_str(s),
ValueRef::Blob(ref b) => serde_json::from_slice(b), ValueRef::Blob(ref b) => serde_json::from_slice(b),
_ => return Err(Error::InvalidColumnType), _ => return Err(Error::InvalidColumnType),
}.map_err(|err| Error::FromSqlConversionFailure(Box::new(err))) }
.map_err(|err| Error::FromSqlConversionFailure(Box::new(err)))
} }
} }

View File

@ -1,17 +1,14 @@
extern crate time; extern crate time;
use libc::c_int;
use {Error, Result}; use {Error, Result};
use types::{FromSql, ToSql, ValueRef}; use types::{FromSql, ToSql, ToSqlOutput, ValueRef};
use ffi::sqlite3_stmt;
const SQLITE_DATETIME_FMT: &'static str = "%Y-%m-%d %H:%M:%S"; const SQLITE_DATETIME_FMT: &'static str = "%Y-%m-%d %H:%M:%S";
impl ToSql for time::Timespec { impl ToSql for time::Timespec {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { fn to_sql(&self) -> Result<ToSqlOutput> {
let time_str = time::at_utc(*self).strftime(SQLITE_DATETIME_FMT).unwrap().to_string(); let time_string = time::at_utc(*self).strftime(SQLITE_DATETIME_FMT).unwrap().to_string();
time_str.bind_parameter(stmt, col) Ok(ToSqlOutput::from(time_string))
} }
} }

View File

@ -1,95 +1,87 @@
use std::mem; use super::{Null, Value, ValueRef};
use ::Result;
use libc::{c_double, c_int}; pub enum ToSqlOutput<'a> {
Borrowed(ValueRef<'a>),
Owned(Value),
use super::Null; #[cfg(feature = "blob")]
use ::{ffi, str_to_cstring}; ZeroBlob(i32),
use ffi::sqlite3_stmt; }
impl<'a, T: ?Sized> From<&'a T> for ToSqlOutput<'a> where &'a T: Into<ValueRef<'a>> {
fn from(t: &'a T) -> Self {
ToSqlOutput::Borrowed(t.into())
}
}
impl<'a, T: Into<Value>> From<T> for ToSqlOutput<'a> {
fn from(t: T) -> Self {
ToSqlOutput::Owned(t.into())
}
}
/// A trait for types that can be converted into SQLite values. /// A trait for types that can be converted into SQLite values.
pub trait ToSql { pub trait ToSql {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int; fn to_sql(&self) -> Result<ToSqlOutput>;
} }
macro_rules! raw_to_impl( // We should be able to use a generic impl like this:
($t:ty, $f:ident) => ( //
// impl<T: Copy> ToSql for T where T: Into<Value> {
// fn to_sql(&self) -> Result<ToSqlOutput> {
// Ok(ToSqlOutput::from((*self).into()))
// }
// }
//
// instead of the following macro, but this runs afoul of
// https://github.com/rust-lang/rust/issues/30191 and reports conflicting
// implementations even when there aren't any.
macro_rules! to_sql_self(
($t:ty) => (
impl ToSql for $t { impl ToSql for $t {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { fn to_sql(&self) -> Result<ToSqlOutput> {
ffi::$f(stmt, col, *self) Ok(ToSqlOutput::from(*self))
} }
} }
) )
); );
raw_to_impl!(c_int, sqlite3_bind_int); // i32 to_sql_self!(Null);
raw_to_impl!(i64, sqlite3_bind_int64); to_sql_self!(i32);
raw_to_impl!(c_double, sqlite3_bind_double); to_sql_self!(i64);
to_sql_self!(f64);
impl ToSql for bool { impl<'a, T: ?Sized> ToSql for &'a T where &'a T: Into<ToSqlOutput<'a>> {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { fn to_sql(&self) -> Result<ToSqlOutput> {
if *self { Ok(ToSqlOutput::from((*self).into()))
ffi::sqlite3_bind_int(stmt, col, 1)
} else {
ffi::sqlite3_bind_int(stmt, col, 0)
}
}
}
impl<'a> ToSql for &'a str {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int {
let length = self.len();
if length > ::std::i32::MAX as usize {
return ffi::SQLITE_TOOBIG;
}
match str_to_cstring(self) {
Ok(c_str) => {
ffi::sqlite3_bind_text(stmt,
col,
c_str.as_ptr(),
length as c_int,
ffi::SQLITE_TRANSIENT())
}
Err(_) => ffi::SQLITE_MISUSE,
}
} }
} }
impl ToSql for String { impl ToSql for String {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { fn to_sql(&self) -> Result<ToSqlOutput> {
(&self[..]).bind_parameter(stmt, col) Ok(ToSqlOutput::from(self.as_str()))
}
}
impl<'a> ToSql for &'a [u8] {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int {
if self.len() > ::std::i32::MAX as usize {
return ffi::SQLITE_TOOBIG;
}
ffi::sqlite3_bind_blob(stmt,
col,
mem::transmute(self.as_ptr()),
self.len() as c_int,
ffi::SQLITE_TRANSIENT())
} }
} }
impl ToSql for Vec<u8> { impl ToSql for Vec<u8> {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { fn to_sql(&self) -> Result<ToSqlOutput> {
(&self[..]).bind_parameter(stmt, col) Ok(ToSqlOutput::from(self.as_slice()))
}
}
impl ToSql for Value {
fn to_sql(&self) -> Result<ToSqlOutput> {
Ok(ToSqlOutput::from(self))
} }
} }
impl<T: ToSql> ToSql for Option<T> { impl<T: ToSql> ToSql for Option<T> {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int { fn to_sql(&self) -> Result<ToSqlOutput> {
match *self { match *self {
None => ffi::sqlite3_bind_null(stmt, col), None => Ok(ToSqlOutput::from(Null)),
Some(ref t) => t.bind_parameter(stmt, col), Some(ref t) => t.to_sql(),
} }
} }
} }
impl ToSql for Null {
unsafe fn bind_parameter(&self, stmt: *mut sqlite3_stmt, col: c_int) -> c_int {
ffi::sqlite3_bind_null(stmt, col)
}
}

View File

@ -70,6 +70,18 @@ impl<'a> From<ValueRef<'a>> for Value {
} }
} }
impl<'a> From<&'a str> for ValueRef<'a> {
fn from(s: &str) -> ValueRef {
ValueRef::Text(s)
}
}
impl<'a> From<&'a [u8]> for ValueRef<'a> {
fn from(s: &[u8]) -> ValueRef {
ValueRef::Blob(s)
}
}
impl<'a> From<&'a Value> for ValueRef<'a> { impl<'a> From<&'a Value> for ValueRef<'a> {
fn from(value: &'a Value) -> ValueRef<'a> { fn from(value: &'a Value) -> ValueRef<'a> {
match *value { match *value {