From 348f94e1092352c565c0cc07b9b676b6fb5cc78e Mon Sep 17 00:00:00 2001 From: John Gallagher Date: Mon, 23 Feb 2015 19:52:48 -0500 Subject: [PATCH] Add public API for loading extensions --- Cargo.toml | 3 ++ src/lib.rs | 97 ++++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 90 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 11cc6a0..0d43f1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,9 @@ license = "MIT" [lib] name = "rusqlite" +[features] +load_extension = ["libsqlite3-sys/load_extension"] + [dependencies] time = "~0.1.0" bitflags = "~0.1" diff --git a/src/lib.rs b/src/lib.rs index a911809..30a2c46 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,7 +48,7 @@ //! } //! } //! ``` -#![feature(unsafe_destructor, core, std_misc, libc, rustc_private, collections, hash)] +#![feature(unsafe_destructor, core, std_misc, path, libc, rustc_private, collections, hash)] #![cfg_attr(test, feature(test))] extern crate libc; @@ -58,11 +58,11 @@ extern crate "libsqlite3-sys" as ffi; use std::mem; use std::ptr; use std::fmt; +use std::path::{Path}; use std::error; use std::rc::{Rc}; use std::cell::{RefCell, Cell}; -use std::ffi::{CString}; -use std::ffi as std_ffi; +use std::ffi::{CStr, CString}; use std::str; use libc::{c_int, c_void, c_char}; @@ -81,7 +81,7 @@ mod transaction; pub type SqliteResult = Result; unsafe fn errmsg_to_string(errmsg: *const c_char) -> String { - let c_slice = std_ffi::c_str_to_bytes(&errmsg); + let c_slice = CStr::from_ptr(errmsg).to_bytes(); let utf8_str = str::from_utf8(c_slice); utf8_str.unwrap_or("Invalid string encoding").to_string() } @@ -121,6 +121,21 @@ impl SqliteError { } } +fn str_to_cstring(s: &str) -> SqliteResult { + CString::new(s).map_err(|_| SqliteError{ + code: ffi::SQLITE_MISUSE, + message: "Could not convert path to C-combatible string".to_string() + }) +} + +fn path_to_cstring(p: &Path) -> SqliteResult { + let s = try!(p.to_str().ok_or(SqliteError{ + code: ffi::SQLITE_MISUSE, + message: "Could not convert path to UTF-8 string".to_string() + })); + str_to_cstring(s) +} + /// A connection to a SQLite database. /// /// ## Warning @@ -320,6 +335,43 @@ impl SqliteConnection { db.close() } + /// Enable loading of SQLite extensions. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{SqliteConnection, SqliteResult}; + /// # use std::path::{Path}; + /// fn load_my_extension(conn: &SqliteConnection) -> SqliteResult<()> { + /// try!(conn.load_extension_enable()); + /// try!(conn.load_extension(Path::new("my_sqlite_extension"), None)); + /// conn.load_extension_disable() + /// } + /// ``` + #[cfg(feature = "load_extension")] + pub fn load_extension_enable(&self) -> SqliteResult<()> { + self.db.borrow_mut().enable_load_extension(1) + } + + /// Disable loading of SQLite extensions. + /// + /// See `load_extension_enable` for an example. + #[cfg(feature = "load_extension")] + pub fn load_extension_disable(&self) -> SqliteResult<()> { + self.db.borrow_mut().enable_load_extension(0) + } + + /// Load the SQLite extension at `dylib_path`. `dylib_path` is passed through to + /// `sqlite3_load_extension`, which may attempt OS-specific modifications if the file + /// cannot be loaded directly. + /// + /// If `entry_point` is `None`, SQLite will attempt to find the entry point. If it is not + /// `None`, the entry point will be passed through to `sqlite3_load_extension`. + #[cfg(feature = "load_extension")] + pub fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> SqliteResult<()> { + self.db.borrow_mut().load_extension(dylib_path, entry_point) + } + fn decode_result(&self, code: c_int) -> SqliteResult<()> { self.db.borrow_mut().decode_result(code) } @@ -392,6 +444,16 @@ impl InnerSqliteConnection { } } + unsafe fn decode_result_with_errmsg(&self, code: c_int, errmsg: *mut c_char) -> SqliteResult<()> { + if code == ffi::SQLITE_OK { + Ok(()) + } else { + let message = errmsg_to_string(&*errmsg); + ffi::sqlite3_free(errmsg as *mut c_void); + Err(SqliteError{ code: code, message: message }) + } + } + fn close(&mut self) -> SqliteResult<()> { let r = unsafe { ffi::sqlite3_close(self.db) }; self.db = ptr::null_mut(); @@ -403,13 +465,28 @@ impl InnerSqliteConnection { unsafe { let mut errmsg: *mut c_char = mem::uninitialized(); let r = ffi::sqlite3_exec(self.db, c_sql.as_ptr(), None, ptr::null_mut(), &mut errmsg); - if r == ffi::SQLITE_OK { - Ok(()) + self.decode_result_with_errmsg(r, errmsg) + } + } + + #[cfg(feature = "load_extension")] + fn enable_load_extension(&mut self, onoff: c_int) -> SqliteResult<()> { + let r = unsafe { ffi::sqlite3_enable_load_extension(self.db, onoff) }; + self.decode_result(r) + } + + #[cfg(feature = "load_extension")] + fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> SqliteResult<()> { + let dylib_str = try!(path_to_cstring(dylib_path)); + unsafe { + let mut errmsg: *mut c_char = mem::uninitialized(); + let r = if let Some(entry_point) = entry_point { + let c_entry = try!(str_to_cstring(entry_point)); + ffi::sqlite3_load_extension(self.db, dylib_str.as_ptr(), c_entry.as_ptr(), &mut errmsg) } else { - let message = errmsg_to_string(&*errmsg); - ffi::sqlite3_free(errmsg as *mut c_void); - Err(SqliteError{ code: r, message: message }) - } + ffi::sqlite3_load_extension(self.db, dylib_str.as_ptr(), ptr::null(), &mut errmsg) + }; + self.decode_result_with_errmsg(r, errmsg) } }