Merge pull request #22 from jgallagher/load-extension

Add calls to load SQLite extensions
This commit is contained in:
John Gallagher 2015-02-23 21:39:24 -05:00
commit 17200cf578
10 changed files with 192 additions and 27 deletions

View File

@ -12,6 +12,12 @@ license = "MIT"
[lib] [lib]
name = "rusqlite" name = "rusqlite"
[features]
load_extension = ["libsqlite3-sys/load_extension"]
[dependencies] [dependencies]
time = "~0.1.0" time = "~0.1.0"
bitflags = "~0.1" bitflags = "~0.1"
[dependencies.libsqlite3-sys]
path = "libsqlite3-sys"

5
build.rs Normal file
View File

@ -0,0 +1,5 @@
extern crate "pkg-config" as pkg_config;
fn main() {
pkg_config::find_library("sqlite3").unwrap();
}

3
libsqlite3-sys/.gitignore vendored Normal file
View File

@ -0,0 +1,3 @@
/target/
/doc/
Cargo.lock

14
libsqlite3-sys/Cargo.toml Normal file
View File

@ -0,0 +1,14 @@
[package]
name = "libsqlite3-sys"
version = "0.0.9"
authors = ["John Gallagher <jgallagher@bignerdranch.com>"]
description = "Native bindings to the libsqlite3 library"
license = "MIT"
links = "sqlite3"
build = "build.rs"
[features]
load_extension = []
[build-dependencies]
pkg-config = "~0.2"

5
libsqlite3-sys/build.rs Normal file
View File

@ -0,0 +1,5 @@
extern crate "pkg-config" as pkg_config;
fn main() {
pkg_config::find_library("sqlite3").unwrap();
}

View File

@ -1,6 +1,5 @@
#![allow(raw_pointer_derive)] #![allow(raw_pointer_derive, non_snake_case, non_camel_case_types)]
/* Running `target/bindgen /Applications/Xcode.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/usr/include/sqlite3.h -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/6.0/include` */ /* Running `target/bindgen /Applications/Xcode.app/Contents/Developer/Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS.sdk/usr/include/sqlite3.h -I/Applications/Xcode.app/Contents/Developer/Toolchains/XcodeDefault.xctoolchain/usr/lib/clang/6.0/include` */
/* automatically generated by rust-bindgen */
#[derive(Copy)] #[derive(Copy)]
pub enum Struct_sqlite3 { } pub enum Struct_sqlite3 { }
@ -521,7 +520,6 @@ pub struct Struct_sqlite3_rtree_geometry {
pub xDelUser: ::std::option::Option<extern "C" fn pub xDelUser: ::std::option::Option<extern "C" fn
(arg1: *mut ::libc::c_void)>, (arg1: *mut ::libc::c_void)>,
} }
#[link(name = "sqlite3")]
extern "C" { extern "C" {
pub static mut sqlite3_version: *const ::libc::c_char; pub static mut sqlite3_version: *const ::libc::c_char;
pub static mut sqlite3_temp_directory: *mut ::libc::c_char; pub static mut sqlite3_temp_directory: *mut ::libc::c_char;
@ -1589,11 +1587,13 @@ extern "C" {
pPrimaryKey: *mut ::libc::c_int, pPrimaryKey: *mut ::libc::c_int,
pAutoinc: *mut ::libc::c_int) pAutoinc: *mut ::libc::c_int)
-> ::libc::c_int; -> ::libc::c_int;
#[cfg(feature = "load_extension")]
pub fn sqlite3_load_extension(db: *mut sqlite3, pub fn sqlite3_load_extension(db: *mut sqlite3,
zFile: *const ::libc::c_char, zFile: *const ::libc::c_char,
zProc: *const ::libc::c_char, zProc: *const ::libc::c_char,
pzErrMsg: *mut *mut ::libc::c_char) pzErrMsg: *mut *mut ::libc::c_char)
-> ::libc::c_int; -> ::libc::c_int;
#[cfg(feature = "load_extension")]
pub fn sqlite3_enable_load_extension(db: *mut sqlite3, pub fn sqlite3_enable_load_extension(db: *mut sqlite3,
onoff: ::libc::c_int) onoff: ::libc::c_int)
-> ::libc::c_int; -> ::libc::c_int;

View File

@ -1,4 +1,10 @@
#![feature(libc)]
#![allow(non_snake_case)]
extern crate libc;
pub use self::bindgen::*; pub use self::bindgen::*;
use std::mem; use std::mem;
use libc::{c_int, c_void}; use libc::{c_int, c_void};
@ -42,7 +48,7 @@ pub const SQLITE_NULL : c_int = 5;
pub type SqliteDestructor = extern "C" fn(*mut c_void); pub type SqliteDestructor = extern "C" fn(*mut c_void);
pub fn SQLITE_TRANSIENT() -> SqliteDestructor { pub fn SQLITE_TRANSIENT() -> SqliteDestructor {
unsafe { mem::transmute(-1is) } unsafe { mem::transmute(-1isize) }
} }
pub fn code_to_str(code: c_int) -> &'static str { pub fn code_to_str(code: c_int) -> &'static str {

View File

@ -48,20 +48,21 @@
//! } //! }
//! } //! }
//! ``` //! ```
#![feature(unsafe_destructor, core, std_misc, libc, rustc_private, collections, hash)] #![feature(unsafe_destructor, core, std_misc, path, libc, rustc_private, collections)]
#![cfg_attr(test, feature(test))] #![cfg_attr(test, feature(test))]
extern crate libc; extern crate libc;
extern crate "libsqlite3-sys" as ffi;
#[macro_use] extern crate rustc_bitflags; #[macro_use] extern crate rustc_bitflags;
use std::mem; use std::mem;
use std::ptr; use std::ptr;
use std::fmt; use std::fmt;
use std::path::{Path};
use std::error; use std::error;
use std::rc::{Rc}; use std::rc::{Rc};
use std::cell::{RefCell, Cell}; use std::cell::{RefCell, Cell};
use std::ffi::{CString}; use std::ffi::{CStr, CString};
use std::ffi as std_ffi;
use std::str; use std::str;
use libc::{c_int, c_void, c_char}; use libc::{c_int, c_void, c_char};
@ -73,17 +74,17 @@ pub use transaction::{SqliteTransactionBehavior,
SqliteTransactionImmediate, SqliteTransactionImmediate,
SqliteTransactionExclusive}; SqliteTransactionExclusive};
#[cfg(feature = "load_extension")] pub use load_extension_guard::{SqliteLoadExtensionGuard};
pub mod types; pub mod types;
mod transaction; mod transaction;
#[cfg(feature = "load_extension")] mod load_extension_guard;
/// Automatically generated FFI bindings (via [bindgen](https://github.com/crabtw/rust-bindgen)).
#[allow(dead_code,non_snake_case,non_camel_case_types)] pub mod ffi;
/// A typedef of the result returned by many methods. /// A typedef of the result returned by many methods.
pub type SqliteResult<T> = Result<T, SqliteError>; pub type SqliteResult<T> = Result<T, SqliteError>;
unsafe fn errmsg_to_string(errmsg: *const c_char) -> String { unsafe fn errmsg_to_string(errmsg: *const c_char) -> String {
let c_slice = std_ffi::c_str_to_bytes(&errmsg); let c_slice = CStr::from_ptr(errmsg).to_bytes();
let utf8_str = str::from_utf8(c_slice); let utf8_str = str::from_utf8(c_slice);
utf8_str.unwrap_or("Invalid string encoding").to_string() utf8_str.unwrap_or("Invalid string encoding").to_string()
} }
@ -123,6 +124,21 @@ impl SqliteError {
} }
} }
fn str_to_cstring(s: &str) -> SqliteResult<CString> {
CString::new(s).map_err(|_| SqliteError{
code: ffi::SQLITE_MISUSE,
message: "Could not convert path to C-combatible string".to_string()
})
}
fn path_to_cstring(p: &Path) -> SqliteResult<CString> {
let s = try!(p.to_str().ok_or(SqliteError{
code: ffi::SQLITE_MISUSE,
message: "Could not convert path to UTF-8 string".to_string()
}));
str_to_cstring(s)
}
/// A connection to a SQLite database. /// A connection to a SQLite database.
/// ///
/// ## Warning /// ## Warning
@ -322,6 +338,55 @@ impl SqliteConnection {
db.close() db.close()
} }
/// Enable loading of SQLite extensions. Strongly consider using `SqliteLoadExtensionGuard`
/// instead of this function.
///
/// ## Example
///
/// ```rust,no_run
/// # use rusqlite::{SqliteConnection, SqliteResult};
/// # use std::path::{Path};
/// fn load_my_extension(conn: &SqliteConnection) -> SqliteResult<()> {
/// try!(conn.load_extension_enable());
/// try!(conn.load_extension(Path::new("my_sqlite_extension"), None));
/// conn.load_extension_disable()
/// }
/// ```
#[cfg(feature = "load_extension")]
pub fn load_extension_enable(&self) -> SqliteResult<()> {
self.db.borrow_mut().enable_load_extension(1)
}
/// Disable loading of SQLite extensions.
///
/// See `load_extension_enable` for an example.
#[cfg(feature = "load_extension")]
pub fn load_extension_disable(&self) -> SqliteResult<()> {
self.db.borrow_mut().enable_load_extension(0)
}
/// Load the SQLite extension at `dylib_path`. `dylib_path` is passed through to
/// `sqlite3_load_extension`, which may attempt OS-specific modifications if the file
/// cannot be loaded directly.
///
/// If `entry_point` is `None`, SQLite will attempt to find the entry point. If it is not
/// `None`, the entry point will be passed through to `sqlite3_load_extension`.
///
/// ## Example
///
/// ```rust,no_run
/// # use rusqlite::{SqliteConnection, SqliteResult, SqliteLoadExtensionGuard};
/// # use std::path::{Path};
/// fn load_my_extension(conn: &SqliteConnection) -> SqliteResult<()> {
/// let _guard = try!(SqliteLoadExtensionGuard::new(conn));
///
/// conn.load_extension(Path::new("my_sqlite_extension"), None)
/// }
#[cfg(feature = "load_extension")]
pub fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> SqliteResult<()> {
self.db.borrow_mut().load_extension(dylib_path, entry_point)
}
fn decode_result(&self, code: c_int) -> SqliteResult<()> { fn decode_result(&self, code: c_int) -> SqliteResult<()> {
self.db.borrow_mut().decode_result(code) self.db.borrow_mut().decode_result(code)
} }
@ -360,7 +425,7 @@ bitflags! {
impl InnerSqliteConnection { impl InnerSqliteConnection {
fn open_with_flags(path: &str, flags: SqliteOpenFlags) -> SqliteResult<InnerSqliteConnection> { fn open_with_flags(path: &str, flags: SqliteOpenFlags) -> SqliteResult<InnerSqliteConnection> {
let c_path = CString::from_slice(path.as_bytes()); let c_path = try!(str_to_cstring(path));
unsafe { unsafe {
let mut db: *mut ffi::sqlite3 = mem::uninitialized(); let mut db: *mut ffi::sqlite3 = mem::uninitialized();
let r = ffi::sqlite3_open_v2(c_path.as_ptr(), &mut db, flags.bits(), ptr::null()); let r = ffi::sqlite3_open_v2(c_path.as_ptr(), &mut db, flags.bits(), ptr::null());
@ -394,6 +459,16 @@ impl InnerSqliteConnection {
} }
} }
unsafe fn decode_result_with_errmsg(&self, code: c_int, errmsg: *mut c_char) -> SqliteResult<()> {
if code == ffi::SQLITE_OK {
Ok(())
} else {
let message = errmsg_to_string(&*errmsg);
ffi::sqlite3_free(errmsg as *mut c_void);
Err(SqliteError{ code: code, message: message })
}
}
fn close(&mut self) -> SqliteResult<()> { fn close(&mut self) -> SqliteResult<()> {
let r = unsafe { ffi::sqlite3_close(self.db) }; let r = unsafe { ffi::sqlite3_close(self.db) };
self.db = ptr::null_mut(); self.db = ptr::null_mut();
@ -401,18 +476,33 @@ impl InnerSqliteConnection {
} }
fn execute_batch(&mut self, sql: &str) -> SqliteResult<()> { fn execute_batch(&mut self, sql: &str) -> SqliteResult<()> {
let c_sql = CString::from_slice(sql.as_bytes()); let c_sql = try!(str_to_cstring(sql));
unsafe { unsafe {
let mut errmsg: *mut c_char = mem::uninitialized(); let mut errmsg: *mut c_char = mem::uninitialized();
let r = ffi::sqlite3_exec(self.db, c_sql.as_ptr(), None, ptr::null_mut(), &mut errmsg); let r = ffi::sqlite3_exec(self.db, c_sql.as_ptr(), None, ptr::null_mut(), &mut errmsg);
if r == ffi::SQLITE_OK { self.decode_result_with_errmsg(r, errmsg)
Ok(())
} else {
let message = errmsg_to_string(&*errmsg);
ffi::sqlite3_free(errmsg as *mut c_void);
Err(SqliteError{ code: r, message: message })
} }
} }
#[cfg(feature = "load_extension")]
fn enable_load_extension(&mut self, onoff: c_int) -> SqliteResult<()> {
let r = unsafe { ffi::sqlite3_enable_load_extension(self.db, onoff) };
self.decode_result(r)
}
#[cfg(feature = "load_extension")]
fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> SqliteResult<()> {
let dylib_str = try!(path_to_cstring(dylib_path));
unsafe {
let mut errmsg: *mut c_char = mem::uninitialized();
let r = if let Some(entry_point) = entry_point {
let c_entry = try!(str_to_cstring(entry_point));
ffi::sqlite3_load_extension(self.db, dylib_str.as_ptr(), c_entry.as_ptr(), &mut errmsg)
} else {
ffi::sqlite3_load_extension(self.db, dylib_str.as_ptr(), ptr::null(), &mut errmsg)
};
self.decode_result_with_errmsg(r, errmsg)
}
} }
fn last_insert_rowid(&self) -> i64 { fn last_insert_rowid(&self) -> i64 {
@ -425,7 +515,7 @@ impl InnerSqliteConnection {
conn: &'a SqliteConnection, conn: &'a SqliteConnection,
sql: &str) -> SqliteResult<SqliteStatement<'a>> { sql: &str) -> SqliteResult<SqliteStatement<'a>> {
let mut c_stmt: *mut ffi::sqlite3_stmt = unsafe { mem::uninitialized() }; let mut c_stmt: *mut ffi::sqlite3_stmt = unsafe { mem::uninitialized() };
let c_sql = CString::from_slice(sql.as_bytes()); let c_sql = try!(str_to_cstring(sql));
let r = unsafe { let r = unsafe {
let len_with_nul = (sql.len() + 1) as c_int; let len_with_nul = (sql.len() + 1) as c_int;
ffi::sqlite3_prepare_v2(self.db, c_sql.as_ptr(), len_with_nul, &mut c_stmt, ffi::sqlite3_prepare_v2(self.db, c_sql.as_ptr(), len_with_nul, &mut c_stmt,
@ -705,6 +795,7 @@ impl<'stmt> SqliteRow<'stmt> {
#[cfg(test)] #[cfg(test)]
mod test { mod test {
extern crate "libsqlite3-sys" as ffi;
use super::*; use super::*;
fn checked_memory_handle() -> SqliteConnection { fn checked_memory_handle() -> SqliteConnection {

View File

@ -0,0 +1,34 @@
use {SqliteResult, SqliteConnection};
/// RAII guard temporarily enabling SQLite extensions to be loaded.
///
/// ## Example
///
/// ```rust,no_run
/// # use rusqlite::{SqliteConnection, SqliteResult, SqliteLoadExtensionGuard};
/// # use std::path::{Path};
/// fn load_my_extension(conn: &SqliteConnection) -> SqliteResult<()> {
/// let _guard = try!(SqliteLoadExtensionGuard::new(conn));
///
/// conn.load_extension(Path::new("my_sqlite_extension"), None)
/// }
/// ```
pub struct SqliteLoadExtensionGuard<'conn> {
conn: &'conn SqliteConnection,
}
impl<'conn> SqliteLoadExtensionGuard<'conn> {
/// Attempt to enable loading extensions. Loading extensions will be disabled when this
/// guard goes out of scope. Cannot be meaningfully nested.
pub fn new(conn: &SqliteConnection) -> SqliteResult<SqliteLoadExtensionGuard> {
conn.load_extension_enable().map(|_| SqliteLoadExtensionGuard{ conn: conn })
}
}
#[unsafe_destructor]
#[allow(unused_must_use)]
impl<'conn> Drop for SqliteLoadExtensionGuard<'conn> {
fn drop(&mut self) {
self.conn.load_extension_disable();
}
}

View File

@ -55,12 +55,11 @@
extern crate time; extern crate time;
use libc::{c_int, c_double, c_char}; use libc::{c_int, c_double, c_char};
use std::ffi as std_ffi; use std::ffi::{CStr};
use std::ffi::{CString};
use std::mem; use std::mem;
use std::str; use std::str;
use super::ffi; use super::ffi;
use super::{SqliteResult, SqliteError}; use super::{SqliteResult, SqliteError, str_to_cstring};
const SQLITE_DATETIME_FMT: &'static str = "%Y-%m-%d %H:%M:%S"; const SQLITE_DATETIME_FMT: &'static str = "%Y-%m-%d %H:%M:%S";
@ -90,8 +89,11 @@ raw_to_impl!(c_double, sqlite3_bind_double);
impl<'a> ToSql for &'a str { impl<'a> ToSql for &'a str {
unsafe fn bind_parameter(&self, stmt: *mut ffi::sqlite3_stmt, col: c_int) -> c_int { unsafe fn bind_parameter(&self, stmt: *mut ffi::sqlite3_stmt, col: c_int) -> c_int {
let c_str = CString::from_slice(self.as_bytes()); if let Ok(c_str) = str_to_cstring(self) {
ffi::sqlite3_bind_text(stmt, col, c_str.as_ptr(), -1, Some(ffi::SQLITE_TRANSIENT())) ffi::sqlite3_bind_text(stmt, col, c_str.as_ptr(), -1, Some(ffi::SQLITE_TRANSIENT()))
} else {
ffi::SQLITE_MISUSE
}
} }
} }
@ -177,8 +179,7 @@ impl FromSql for String {
if c_text.is_null() { if c_text.is_null() {
Ok("".to_string()) Ok("".to_string())
} else { } else {
let c_text = c_text as *const c_char; let c_slice = CStr::from_ptr(c_text as *const c_char).to_bytes();
let c_slice = std_ffi::c_str_to_bytes(&c_text);
let utf8_str = str::from_utf8(c_slice); let utf8_str = str::from_utf8(c_slice);
utf8_str utf8_str
.map(|s| { s.to_string() }) .map(|s| { s.to_string() })