Merge remote-tracking branch 'jgallagher/master' into tail

This commit is contained in:
gwenn 2019-08-26 20:43:39 +02:00
commit c6a5fd402c
8 changed files with 67 additions and 25 deletions

View File

@ -28,7 +28,7 @@ script:
- cargo build --features sqlcipher - cargo build --features sqlcipher
- cargo build --features "bundled sqlcipher" - cargo build --features "bundled sqlcipher"
- cargo test - cargo test
- cargo test --features "backup blob" - cargo test --features "backup blob extra_check"
- cargo test --features "collation functions" - cargo test --features "collation functions"
- cargo test --features "hooks limits" - cargo test --features "hooks limits"
- cargo test --features load_extension - cargo test --features load_extension

View File

@ -48,12 +48,13 @@ csvtab = ["csv", "vtab"]
# pointer passing interfaces: 3.20.0 # pointer passing interfaces: 3.20.0
array = ["vtab"] array = ["vtab"]
# session extension: 3.13.0 # session extension: 3.13.0
session = ["libsqlite3-sys/session", "hooks"] #session = ["libsqlite3-sys/session", "hooks"]
# window functions: 3.25.0 # window functions: 3.25.0
window = ["functions"] window = ["functions"]
# 3.9.0 # 3.9.0
series = ["vtab"] series = ["vtab"]
# check for invalid query.
extra_check = []
[dependencies] [dependencies]
time = "0.1.0" time = "0.1.0"

View File

@ -125,7 +125,9 @@ impl InnerConnection {
str::from_utf8_unchecked(c_slice) str::from_utf8_unchecked(c_slice)
}; };
callback(&conn, collation_name) callback(&conn, collation_name)
}).is_err() { })
.is_err()
{
return; // FIXME How ? return; // FIXME How ?
} }
} }

View File

@ -102,7 +102,8 @@ pub enum Error {
GetAuxWrongType, GetAuxWrongType,
/// Error when the SQL contains multiple statements. /// Error when the SQL contains multiple statements.
MultipleStatement,} MultipleStatement,
}
impl PartialEq for Error { impl PartialEq for Error {
fn eq(&self, other: &Error) -> bool { fn eq(&self, other: &Error) -> bool {

View File

@ -1,6 +1,6 @@
use std::ffi::CString; use std::ffi::CString;
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::os::raw::c_int; use std::os::raw::{c_char, c_int};
#[cfg(feature = "load_extension")] #[cfg(feature = "load_extension")]
use std::path::Path; use std::path::Path;
use std::ptr; use std::ptr;
@ -178,8 +178,6 @@ impl InnerConnection {
#[cfg(feature = "load_extension")] #[cfg(feature = "load_extension")]
pub fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> Result<()> { pub fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> Result<()> {
use std::os::raw::c_char;
let dylib_str = super::path_to_cstring(dylib_path)?; let dylib_str = super::path_to_cstring(dylib_path)?;
unsafe { unsafe {
let mut errmsg = MaybeUninit::uninit(); let mut errmsg = MaybeUninit::uninit();
@ -217,7 +215,7 @@ impl InnerConnection {
pub fn prepare<'a>(&mut self, conn: &'a Connection, sql: &str) -> Result<Statement<'a>> { pub fn prepare<'a>(&mut self, conn: &'a Connection, sql: &str) -> Result<Statement<'a>> {
let mut c_stmt = MaybeUninit::uninit(); let mut c_stmt = MaybeUninit::uninit();
let (c_sql, len, _) = str_for_sqlite(sql.as_bytes())?; let (c_sql, len, _) = str_for_sqlite(sql.as_bytes())?;
let mut c_tail = ptr::null(); let mut c_tail = MaybeUninit::uninit();
let r = unsafe { let r = unsafe {
if cfg!(feature = "unlock_notify") { if cfg!(feature = "unlock_notify") {
let mut rc; let mut rc;
@ -227,7 +225,7 @@ impl InnerConnection {
c_sql, c_sql,
len, len,
c_stmt.as_mut_ptr(), c_stmt.as_mut_ptr(),
&mut c_tail, c_tail.as_mut_ptr(),
); );
if !unlock_notify::is_locked(self.db, rc) { if !unlock_notify::is_locked(self.db, rc) {
break; break;
@ -239,16 +237,22 @@ impl InnerConnection {
} }
rc rc
} else { } else {
ffi::sqlite3_prepare_v2(self.db(), c_sql, len, c_stmt.as_mut_ptr(), &mut c_tail) ffi::sqlite3_prepare_v2(
self.db(),
c_sql,
len,
c_stmt.as_mut_ptr(),
c_tail.as_mut_ptr(),
)
} }
}; };
self.decode_result(r)?;
let c_stmt: *mut ffi::sqlite3_stmt = unsafe { c_stmt.assume_init() }; let c_stmt: *mut ffi::sqlite3_stmt = unsafe { c_stmt.assume_init() };
if !c_tail.is_null() && unsafe { *c_tail == 0 } { let c_tail: *const c_char = unsafe { c_tail.assume_init() };
// '\0' when there is no ';' at the end // TODO ignore spaces, comments, ... at the end
c_tail = ptr::null(); // TODO ignore spaces, comments, ... at the end let tail = !c_tail.is_null() && unsafe { c_tail != c_sql.offset(len as isize) };
} Ok(Statement::new(conn, RawStatement::new(c_stmt, tail)))
self.decode_result(r)
.map(|_| Statement::new(conn, RawStatement::new(c_stmt, c_tail)))
} }
pub fn changes(&mut self) -> usize { pub fn changes(&mut self) -> usize {

View File

@ -896,7 +896,8 @@ mod test {
) )
.expect("create temp db"); .expect("create temp db");
let mut db1 = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_WRITE).unwrap(); let mut db1 =
Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_WRITE).unwrap();
let mut db2 = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_ONLY).unwrap(); let mut db2 = Connection::open_with_flags(&path, OpenFlags::SQLITE_OPEN_READ_ONLY).unwrap();
db1.busy_timeout(Duration::from_millis(0)).unwrap(); db1.busy_timeout(Duration::from_millis(0)).unwrap();
@ -1066,7 +1067,8 @@ mod test {
.execute( .execute(
"CREATE TABLE foo(x INTEGER); CREATE TABLE foo(x INTEGER)", "CREATE TABLE foo(x INTEGER); CREATE TABLE foo(x INTEGER)",
NO_PARAMS, NO_PARAMS,
).unwrap_err(); )
.unwrap_err();
match err { match err {
Error::MultipleStatement => (), Error::MultipleStatement => (),
_ => panic!("Unexpected error: {}", err), _ => panic!("Unexpected error: {}", err),

View File

@ -2,15 +2,15 @@ use super::ffi;
use super::unlock_notify; use super::unlock_notify;
use super::StatementStatus; use super::StatementStatus;
use std::ffi::CStr; use std::ffi::CStr;
use std::os::raw::{c_char, c_int}; use std::os::raw::c_int;
use std::ptr; use std::ptr;
// Private newtype for raw sqlite3_stmts that finalize themselves when dropped. // Private newtype for raw sqlite3_stmts that finalize themselves when dropped.
#[derive(Debug)] #[derive(Debug)]
pub struct RawStatement(*mut ffi::sqlite3_stmt, *const c_char); pub struct RawStatement(*mut ffi::sqlite3_stmt, bool);
impl RawStatement { impl RawStatement {
pub fn new(stmt: *mut ffi::sqlite3_stmt, tail: *const c_char) -> RawStatement { pub fn new(stmt: *mut ffi::sqlite3_stmt, tail: bool) -> RawStatement {
RawStatement(stmt, tail) RawStatement(stmt, tail)
} }
@ -131,7 +131,7 @@ impl RawStatement {
} }
pub fn has_tail(&self) -> bool { pub fn has_tail(&self) -> bool {
!self.1.is_null() self.1
} }
} }

View File

@ -511,6 +511,7 @@ impl Statement<'_> {
} }
fn execute_with_bound_parameters(&mut self) -> Result<usize> { fn execute_with_bound_parameters(&mut self) -> Result<usize> {
self.check_update()?;
let r = self.stmt.step(); let r = self.stmt.step();
self.stmt.reset(); self.stmt.reset();
match r { match r {
@ -527,7 +528,7 @@ impl Statement<'_> {
} }
fn finalize_(&mut self) -> Result<()> { fn finalize_(&mut self) -> Result<()> {
let mut stmt = RawStatement::new(ptr::null_mut(), ptr::null()); let mut stmt = RawStatement::new(ptr::null_mut(), false);
mem::swap(&mut stmt, &mut self.stmt); mem::swap(&mut stmt, &mut self.stmt);
self.conn.decode_result(stmt.finalize()) self.conn.decode_result(stmt.finalize())
} }
@ -547,6 +548,30 @@ impl Statement<'_> {
Ok(()) Ok(())
} }
#[cfg(all(feature = "bundled", feature = "extra_check"))]
#[inline]
fn check_update(&self) -> Result<()> {
if self.column_count() > 0 || self.stmt.readonly() {
return Err(Error::ExecuteReturnedResults);
}
Ok(())
}
#[cfg(all(not(feature = "bundled"), feature = "extra_check"))]
#[inline]
fn check_update(&self) -> Result<()> {
if self.column_count() > 0 {
return Err(Error::ExecuteReturnedResults);
}
Ok(())
}
#[cfg(not(feature = "extra_check"))]
#[inline]
fn check_update(&self) -> Result<()> {
Ok(())
}
/// Returns a string containing the SQL text of prepared statement with /// Returns a string containing the SQL text of prepared statement with
/// bound parameters expanded. /// bound parameters expanded.
#[cfg(feature = "bundled")] #[cfg(feature = "bundled")]
@ -574,6 +599,7 @@ impl Statement<'_> {
self.stmt.get_status(status, true) self.stmt.get_status(status, true)
} }
#[cfg(feature = "extra_check")]
pub(crate) fn check_no_tail(&self) -> Result<()> { pub(crate) fn check_no_tail(&self) -> Result<()> {
if self.stmt.has_tail() { if self.stmt.has_tail() {
Err(Error::MultipleStatement) Err(Error::MultipleStatement)
@ -581,11 +607,17 @@ impl Statement<'_> {
Ok(()) Ok(())
} }
} }
#[cfg(not(feature = "extra_check"))]
#[inline]
pub(crate) fn check_no_tail(&self) -> Result<()> {
Ok(())
}
} }
impl Into<RawStatement> for Statement<'_> { impl Into<RawStatement> for Statement<'_> {
fn into(mut self) -> RawStatement { fn into(mut self) -> RawStatement {
let mut stmt = RawStatement::new(ptr::null_mut(), ptr::null()); let mut stmt = RawStatement::new(ptr::null_mut(), false);
mem::swap(&mut stmt, &mut self.stmt); mem::swap(&mut stmt, &mut self.stmt);
stmt stmt
} }