diff --git a/src/statement.rs b/src/statement.rs index 2769d6f..fde89ec 100644 --- a/src/statement.rs +++ b/src/statement.rs @@ -1,6 +1,5 @@ -use std::ffi::CStr; use std::iter::IntoIterator; -use std::os::raw::{c_char, c_int, c_void}; +use std::os::raw::{c_int, c_void}; #[cfg(feature = "array")] use std::rc::Rc; use std::slice::from_raw_parts; @@ -649,7 +648,8 @@ impl Statement<'_> { pub(crate) fn value_ref(&self, col: usize) -> ValueRef<'_> { let raw = unsafe { self.stmt.ptr() }; - match self.stmt.column_type(col) { + let kind = self.stmt.column_type(col); + match kind { ffi::SQLITE_NULL => ValueRef::Null, ffi::SQLITE_INTEGER => { ValueRef::Integer(unsafe { ffi::sqlite3_column_int64(raw, col as c_int) }) @@ -657,20 +657,7 @@ impl Statement<'_> { ffi::SQLITE_FLOAT => { ValueRef::Real(unsafe { ffi::sqlite3_column_double(raw, col as c_int) }) } - ffi::SQLITE_TEXT => { - let s = unsafe { - let text = ffi::sqlite3_column_text(raw, col as c_int); - assert!( - !text.is_null(), - "unexpected SQLITE_TEXT column type with NULL data" - ); - CStr::from_ptr(text as *const c_char) - }; - - let s = s.to_bytes(); - ValueRef::Text(s) - } - ffi::SQLITE_BLOB => { + ffi::SQLITE_TEXT | ffi::SQLITE_BLOB => { let (blob, len) = unsafe { ( ffi::sqlite3_column_blob(raw, col as c_int), @@ -682,16 +669,23 @@ impl Statement<'_> { len >= 0, "unexpected negative return from sqlite3_column_bytes" ); - if len > 0 { + + let bytes = if len > 0 { assert!( !blob.is_null(), - "unexpected SQLITE_BLOB column type with NULL data" + "unexpected SQLITE_BLOB/TEXT column type with NULL data" ); - ValueRef::Blob(unsafe { from_raw_parts(blob as *const u8, len as usize) }) + unsafe { from_raw_parts(blob as *const u8, len as usize) } } else { // The return value from sqlite3_column_blob() for a zero-length BLOB // is a NULL pointer. - ValueRef::Blob(&[]) + &[] + }; + + if kind == ffi::SQLITE_TEXT { + ValueRef::Text(bytes) + } else { + ValueRef::Blob(bytes) } } _ => unreachable!("sqlite3_column_type returned invalid value"), @@ -740,7 +734,7 @@ pub enum StatementStatus { #[cfg(test)] mod test { use crate::types::ToSql; - use crate::{Connection, Error, Result, NO_PARAMS}; + use crate::{params, Connection, Error, Result, NO_PARAMS}; #[test] fn test_execute_named() { @@ -1089,4 +1083,21 @@ mod test { let stmt = conn.prepare(";").unwrap(); assert_eq!(0, stmt.column_count()); } + + #[test] + fn test_nul_in_text_column() { + let db = Connection::open_in_memory().unwrap(); + db.execute_batch("create table tbl(val text);").unwrap(); + + let expected = "a\x00b".to_string(); + db.execute( + "insert into tbl values (cast(? as text))", + params![expected.as_bytes()], + ) + .unwrap(); + let text: String = db + .query_row("select val from tbl", NO_PARAMS, |row| row.get(0)) + .unwrap(); + assert_eq!(text, expected); + } }