diff --git a/examples/loadable_extension.rs b/examples/loadable_extension.rs index 809d42f..e913240 100644 --- a/examples/loadable_extension.rs +++ b/examples/loadable_extension.rs @@ -25,7 +25,9 @@ pub extern "C" fn sqlite3_extension_init( pz_err_msg: *mut *mut c_char, p_api: *mut ffi::sqlite3_api_routines, ) -> c_int { - if let Err(err) = extension_init(db, p_api) { + if p_api.is_null() { + return ffi::SQLITE_ERROR; + } else if let Err(err) = extension_init(db, p_api) { return unsafe { to_sqlite_error(&err, pz_err_msg) }; } ffi::SQLITE_OK diff --git a/libsqlite3-sys/bindgen-bindings/bindgen_3.14.0_ext.rs b/libsqlite3-sys/bindgen-bindings/bindgen_3.14.0_ext.rs index f6c4cb6..a8e1afc 100644 --- a/libsqlite3-sys/bindgen-bindings/bindgen_3.14.0_ext.rs +++ b/libsqlite3-sys/bindgen-bindings/bindgen_3.14.0_ext.rs @@ -6267,9 +6267,7 @@ pub unsafe fn sqlite3_expanded_sql( pub unsafe fn rusqlite_extension_init2( p_api: *mut sqlite3_api_routines, ) -> ::std::result::Result<(), crate::InitError> { - if p_api.is_null() { - return Err(crate::InitError::NullApiPointer); - } + __SQLITE3_MALLOC.store((*p_api).malloc, ::atomic::Ordering::Release); if let Some(fun) = (*p_api).libversion_number { let version = fun(); if SQLITE_VERSION_NUMBER > version { @@ -6371,7 +6369,6 @@ pub unsafe fn rusqlite_extension_init2( __SQLITE3_LIBVERSION.store((*p_api).libversion, ::atomic::Ordering::Release); __SQLITE3_LIBVERSION_NUMBER .store((*p_api).libversion_number, ::atomic::Ordering::Release); - __SQLITE3_MALLOC.store((*p_api).malloc, ::atomic::Ordering::Release); __SQLITE3_OPEN.store((*p_api).open, ::atomic::Ordering::Release); __SQLITE3_OPEN16.store((*p_api).open16, ::atomic::Ordering::Release); __SQLITE3_PREPARE.store((*p_api).prepare, ::atomic::Ordering::Release); diff --git a/libsqlite3-sys/build.rs b/libsqlite3-sys/build.rs index 49efffb..dd42653 100644 --- a/libsqlite3-sys/build.rs +++ b/libsqlite3-sys/build.rs @@ -702,6 +702,7 @@ mod loadable_extension { let sqlite3_api_routines_ident = sqlite3_api_routines.ident; let p_api = quote::format_ident!("p_api"); let mut stores = Vec::new(); + let mut malloc = Vec::new(); // (2) `#define sqlite3_xyz sqlite3_api->abc` => `pub unsafe fn // sqlite3_xyz(args) -> ty {...}` for each `abc` field: for field in sqlite3_api_routines.fields { @@ -764,7 +765,12 @@ mod loadable_extension { &syn::parse2(tokens).expect("could not parse quote output"), )); output.push('\n'); - stores.push(quote::quote! { + if name == "malloc" { + &mut malloc + } else { + &mut stores + } + .push(quote::quote! { #ptr_name.store( (*#p_api).#ident, ::atomic::Ordering::Release, @@ -775,9 +781,7 @@ mod loadable_extension { let tokens = quote::quote! { /// Like SQLITE_EXTENSION_INIT2 macro pub unsafe fn rusqlite_extension_init2(#p_api: *mut #sqlite3_api_routines_ident) -> ::std::result::Result<(),crate::InitError> { - if #p_api.is_null() { - return Err(crate::InitError::NullApiPointer); - } + #(#malloc)* // sqlite3_malloc needed by to_sqlite_error if let Some(fun) = (*#p_api).libversion_number { let version = fun(); if SQLITE_VERSION_NUMBER > version { diff --git a/libsqlite3-sys/sqlite3/bindgen_bundled_version_ext.rs b/libsqlite3-sys/sqlite3/bindgen_bundled_version_ext.rs index a97023e..b4f9c5d 100644 --- a/libsqlite3-sys/sqlite3/bindgen_bundled_version_ext.rs +++ b/libsqlite3-sys/sqlite3/bindgen_bundled_version_ext.rs @@ -7463,9 +7463,7 @@ pub unsafe fn sqlite3_is_interrupted(arg1: *mut sqlite3) -> ::std::os::raw::c_in pub unsafe fn rusqlite_extension_init2( p_api: *mut sqlite3_api_routines, ) -> ::std::result::Result<(), crate::InitError> { - if p_api.is_null() { - return Err(crate::InitError::NullApiPointer); - } + __SQLITE3_MALLOC.store((*p_api).malloc, ::atomic::Ordering::Release); if let Some(fun) = (*p_api).libversion_number { let version = fun(); if SQLITE_VERSION_NUMBER > version { @@ -7567,7 +7565,6 @@ pub unsafe fn rusqlite_extension_init2( __SQLITE3_LIBVERSION.store((*p_api).libversion, ::atomic::Ordering::Release); __SQLITE3_LIBVERSION_NUMBER .store((*p_api).libversion_number, ::atomic::Ordering::Release); - __SQLITE3_MALLOC.store((*p_api).malloc, ::atomic::Ordering::Release); __SQLITE3_OPEN.store((*p_api).open, ::atomic::Ordering::Release); __SQLITE3_OPEN16.store((*p_api).open16, ::atomic::Ordering::Release); __SQLITE3_PREPARE.store((*p_api).prepare, ::atomic::Ordering::Release); diff --git a/libsqlite3-sys/src/error.rs b/libsqlite3-sys/src/error.rs index 4e91547..f93d20a 100644 --- a/libsqlite3-sys/src/error.rs +++ b/libsqlite3-sys/src/error.rs @@ -275,8 +275,6 @@ pub fn code_to_str(code: c_int) -> &'static str { #[derive(Clone, Copy, Debug, PartialEq, Eq)] #[non_exhaustive] pub enum InitError { - /// Invalid sqlite3_api_routines pointer - NullApiPointer, /// Version mismatch between the extension and the SQLite3 library VersionMismatch { compile_time: i32, runtime: i32 }, /// Invalid function pointer in one of sqlite3_api_routines fields @@ -286,9 +284,6 @@ pub enum InitError { impl ::std::fmt::Display for InitError { fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { match *self { - InitError::NullApiPointer => { - write!(f, "Invalid sqlite3_api_routines pointer") - } InitError::VersionMismatch { compile_time, runtime,