diff --git a/src/lib.rs b/src/lib.rs index 30a2c46..8286d7a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -74,8 +74,11 @@ pub use transaction::{SqliteTransactionBehavior, SqliteTransactionImmediate, SqliteTransactionExclusive}; +#[cfg(feature = "load_extension")] pub use load_extension_guard::{SqliteLoadExtensionGuard}; + pub mod types; mod transaction; +#[cfg(feature = "load_extension")] mod load_extension_guard; /// A typedef of the result returned by many methods. pub type SqliteResult = Result; @@ -335,7 +338,8 @@ impl SqliteConnection { db.close() } - /// Enable loading of SQLite extensions. + /// Enable loading of SQLite extensions. Strongly consider using `SqliteLoadExtensionGuard` + /// instead of this function. /// /// ## Example /// @@ -367,6 +371,17 @@ impl SqliteConnection { /// /// If `entry_point` is `None`, SQLite will attempt to find the entry point. If it is not /// `None`, the entry point will be passed through to `sqlite3_load_extension`. + /// + /// ## Example + /// + /// ```rust,no_run + /// # use rusqlite::{SqliteConnection, SqliteResult, SqliteLoadExtensionGuard}; + /// # use std::path::{Path}; + /// fn load_my_extension(conn: &SqliteConnection) -> SqliteResult<()> { + /// let _guard = try!(SqliteLoadExtensionGuard::new(conn)); + /// + /// conn.load_extension(Path::new("my_sqlite_extension"), None) + /// } #[cfg(feature = "load_extension")] pub fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> SqliteResult<()> { self.db.borrow_mut().load_extension(dylib_path, entry_point) diff --git a/src/load_extension_guard.rs b/src/load_extension_guard.rs new file mode 100644 index 0000000..82a7c89 --- /dev/null +++ b/src/load_extension_guard.rs @@ -0,0 +1,34 @@ +use {SqliteResult, SqliteConnection}; + +/// RAII guard temporarily enabling SQLite extensions to be loaded. +/// +/// ## Example +/// +/// ```rust,no_run +/// # use rusqlite::{SqliteConnection, SqliteResult, SqliteLoadExtensionGuard}; +/// # use std::path::{Path}; +/// fn load_my_extension(conn: &SqliteConnection) -> SqliteResult<()> { +/// let _guard = try!(SqliteLoadExtensionGuard::new(conn)); +/// +/// conn.load_extension(Path::new("my_sqlite_extension"), None) +/// } +/// ``` +pub struct SqliteLoadExtensionGuard<'conn> { + conn: &'conn SqliteConnection, +} + +impl<'conn> SqliteLoadExtensionGuard<'conn> { + /// Attempt to enable loading extensions. Loading extensions will be disabled when this + /// guard goes out of scope. Cannot be meaningfully nested. + pub fn new(conn: &SqliteConnection) -> SqliteResult { + conn.load_extension_enable().map(|_| SqliteLoadExtensionGuard{ conn: conn }) + } +} + +#[unsafe_destructor] +#[allow(unused_must_use)] +impl<'conn> Drop for SqliteLoadExtensionGuard<'conn> { + fn drop(&mut self) { + self.conn.load_extension_disable(); + } +}