diff --git a/src/functions.rs b/src/functions.rs index c80b407..a489475 100644 --- a/src/functions.rs +++ b/src/functions.rs @@ -54,7 +54,7 @@ use std::ffi::CStr; use std::mem; use std::ptr; use std::slice; -use libc::{c_int, c_double, c_char, c_void}; +use libc::{c_int, c_char, c_void}; use ffi; pub use ffi::sqlite3_context; @@ -62,100 +62,73 @@ pub use ffi::sqlite3_value; pub use ffi::sqlite3_value_type; pub use ffi::sqlite3_value_numeric_type; -use types::{Null, FromSql, ValueRef}; +use types::{ToSql, ToSqlOutput, FromSql, ValueRef}; use {Result, Error, Connection, str_to_cstring, InnerConnection}; -/// A trait for types that can be converted into the result of an SQL function. -pub trait ToResult { - unsafe fn set_result(&self, ctx: *mut sqlite3_context); +fn set_result<'a>(ctx: *mut sqlite3_context, result: &ToSqlOutput<'a>) { + let value = match *result { + ToSqlOutput::Borrowed(v) => v, + ToSqlOutput::Owned(ref v) => ValueRef::from(v), + + #[cfg(feature = "blob")] + ToSqlOutput::ZeroBlob(len) => { + return unsafe { ffi::sqlite3_result_zeroblob(ctx, len) }; + } + }; + + match value { + ValueRef::Null => unsafe { ffi::sqlite3_result_null(ctx) }, + ValueRef::Integer(i) => unsafe { ffi::sqlite3_result_int64(ctx, i) }, + ValueRef::Real(r) => unsafe { ffi::sqlite3_result_double(ctx, r) }, + ValueRef::Text(ref s) => unsafe { + let length = s.len(); + if length > ::std::i32::MAX as usize { + ffi::sqlite3_result_error_toobig(ctx); + } else { + let c_str = match str_to_cstring(s) { + Ok(c_str) => c_str, + // TODO sqlite3_result_error + Err(_) => return ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISUSE), + }; + let destructor = if length > 0 { + ffi::SQLITE_TRANSIENT() + } else { + ffi::SQLITE_STATIC() + }; + ffi::sqlite3_result_text(ctx, c_str.as_ptr(), length as c_int, destructor); + } + }, + ValueRef::Blob(ref b) => unsafe { + let length = b.len(); + if length > ::std::i32::MAX as usize { + ffi::sqlite3_result_error_toobig(ctx); + } else if length == 0 { + ffi::sqlite3_result_zeroblob(ctx, 0) + } else { + ffi::sqlite3_result_blob(ctx, b.as_ptr() as *const c_void, length as c_int, ffi::SQLITE_TRANSIENT()); + } + }, + } } -macro_rules! raw_to_impl( - ($t:ty, $f:ident) => ( - impl ToResult for $t { - unsafe fn set_result(&self, ctx: *mut sqlite3_context) { - ffi::$f(ctx, *self) +unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) { + match err { + &Error::SqliteFailure(ref err, ref s) => { + ffi::sqlite3_result_error_code(ctx, err.extended_code); + if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); } } - ) -); - -raw_to_impl!(c_int, sqlite3_result_int); -raw_to_impl!(i64, sqlite3_result_int64); -raw_to_impl!(c_double, sqlite3_result_double); - -impl<'a> ToResult for bool { - unsafe fn set_result(&self, ctx: *mut sqlite3_context) { - if *self { - ffi::sqlite3_result_int(ctx, 1) - } else { - ffi::sqlite3_result_int(ctx, 0) - } - } -} - - -impl<'a> ToResult for &'a str { - unsafe fn set_result(&self, ctx: *mut sqlite3_context) { - let length = self.len(); - if length > ::std::i32::MAX as usize { - ffi::sqlite3_result_error_toobig(ctx); - return; - } - match str_to_cstring(self) { - Ok(c_str) => { - ffi::sqlite3_result_text(ctx, - c_str.as_ptr(), - length as c_int, - ffi::SQLITE_TRANSIENT()) + _ => { + ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); + if let Ok(cstr) = str_to_cstring(err.description()) { + ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); } - // TODO sqlite3_result_error - Err(_) => ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_MISUSE), } } } -impl ToResult for String { - unsafe fn set_result(&self, ctx: *mut sqlite3_context) { - (&self[..]).set_result(ctx) - } -} - -impl<'a> ToResult for &'a [u8] { - unsafe fn set_result(&self, ctx: *mut sqlite3_context) { - if self.len() > ::std::i32::MAX as usize { - ffi::sqlite3_result_error_toobig(ctx); - return; - } - ffi::sqlite3_result_blob(ctx, - mem::transmute(self.as_ptr()), - self.len() as c_int, - ffi::SQLITE_TRANSIENT()) - } -} - -impl ToResult for Vec { - unsafe fn set_result(&self, ctx: *mut sqlite3_context) { - (&self[..]).set_result(ctx) - } -} - -impl ToResult for Option { - unsafe fn set_result(&self, ctx: *mut sqlite3_context) { - match *self { - None => ffi::sqlite3_result_null(ctx), - Some(ref t) => t.set_result(ctx), - } - } -} - -impl ToResult for Null { - unsafe fn set_result(&self, ctx: *mut sqlite3_context) { - ffi::sqlite3_result_null(ctx) - } -} - impl<'a> ValueRef<'a> { unsafe fn from_value(value: *mut sqlite3_value) -> ValueRef<'a> { use std::slice::from_raw_parts; @@ -259,7 +232,7 @@ impl<'a> Context<'a> { /// `A` is the type of the aggregation context and `T` is the type of the final result. /// Implementations should be stateless. pub trait Aggregate - where T: ToResult + where T: ToSql { /// Initializes the aggregation context. Will be called prior to the first call /// to `step()` to set up the context for an invocation of the function. (Note: @@ -316,7 +289,7 @@ impl Connection { x_func: F) -> Result<()> where F: FnMut(&Context) -> Result, - T: ToResult + T: ToSql { self.db.borrow_mut().create_scalar_function(fn_name, n_arg, deterministic, x_func) } @@ -333,7 +306,7 @@ impl Connection { aggr: D) -> Result<()> where D: Aggregate, - T: ToResult + T: ToSql { self.db .borrow_mut() @@ -361,13 +334,13 @@ impl InnerConnection { x_func: F) -> Result<()> where F: FnMut(&Context) -> Result, - T: ToResult + T: ToSql { unsafe extern "C" fn call_boxed_closure(ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value) where F: FnMut(&Context) -> Result, - T: ToResult + T: ToSql { let ctx = Context { ctx: ctx, @@ -375,20 +348,14 @@ impl InnerConnection { }; let boxed_f: *mut F = mem::transmute(ffi::sqlite3_user_data(ctx.ctx)); assert!(!boxed_f.is_null(), "Internal error - null function pointer"); - match (*boxed_f)(&ctx) { - Ok(r) => r.set_result(ctx.ctx), - Err(Error::SqliteFailure(err, s)) => { - ffi::sqlite3_result_error_code(ctx.ctx, err.extended_code); - if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { - ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); - } - } - Err(err) => { - ffi::sqlite3_result_error_code(ctx.ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); - if let Ok(cstr) = str_to_cstring(err.description()) { - ffi::sqlite3_result_error(ctx.ctx, cstr.as_ptr(), -1); - } - } + + let t = (*boxed_f)(&ctx); + let t = t.as_ref().map(|t| ToSql::to_sql(t)); + + match t { + Ok(Ok(ref value)) => set_result(ctx.ctx, value), + Ok(Err(err)) => report_error(ctx.ctx, &err), + Err(err) => report_error(ctx.ctx, err), } } @@ -419,7 +386,7 @@ impl InnerConnection { aggr: D) -> Result<()> where D: Aggregate, - T: ToResult + T: ToSql { unsafe fn aggregate_context(ctx: *mut sqlite3_context, bytes: usize) @@ -431,28 +398,11 @@ impl InnerConnection { Some(pac) } - unsafe fn report_aggregate_error(ctx: *mut sqlite3_context, err: Error) { - match err { - Error::SqliteFailure(err, s) => { - ffi::sqlite3_result_error_code(ctx, err.extended_code); - if let Some(Ok(cstr)) = s.map(|s| str_to_cstring(&s)) { - ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); - } - } - _ => { - ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION); - if let Ok(cstr) = str_to_cstring(err.description()) { - ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1); - } - } - } - } - unsafe extern "C" fn call_boxed_step(ctx: *mut sqlite3_context, argc: c_int, argv: *mut *mut sqlite3_value) where D: Aggregate, - T: ToResult + T: ToSql { let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); assert!(!boxed_aggr.is_null(), @@ -477,13 +427,13 @@ impl InnerConnection { match (*boxed_aggr).step(&mut ctx, &mut **pac) { Ok(_) => {} - Err(err) => report_aggregate_error(ctx.ctx, err), + Err(err) => report_error(ctx.ctx, &err), }; } unsafe extern "C" fn call_boxed_final(ctx: *mut sqlite3_context) where D: Aggregate, - T: ToResult + T: ToSql { let boxed_aggr: *mut D = mem::transmute(ffi::sqlite3_user_data(ctx)); assert!(!boxed_aggr.is_null(), @@ -503,10 +453,13 @@ impl InnerConnection { None => None, }; - match (*boxed_aggr).finalize(a) { - Ok(r) => r.set_result(ctx), - Err(err) => report_aggregate_error(ctx, err), - }; + let t = (*boxed_aggr).finalize(a); + let t = t.as_ref().map(|t| ToSql::to_sql(t)); + match t { + Ok(Ok(ref value)) => set_result(ctx, value), + Ok(Err(err)) => report_error(ctx, &err), + Err(err) => report_error(ctx, err), + } } let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));