diff --git a/src/blob.rs b/src/blob.rs index b4f702e..c7e3b51 100644 --- a/src/blob.rs +++ b/src/blob.rs @@ -1,10 +1,11 @@ //! incremental BLOB I/O use std::io; +use std::cmp::min; use std::mem; use std::ptr; use super::ffi; -use {Error, Result, Connection, DatabaseName}; +use {Result, Connection, DatabaseName}; /// Handle to an open BLOB pub struct Blob<'conn> { @@ -96,20 +97,10 @@ impl<'conn> io::Read for Blob<'conn> { /// /// # Failure /// - /// Will return `Err` if `buf` length > i32 max value or if the underlying SQLite read call fails. + /// Will return `Err` if the underlying SQLite read call fails. fn read(&mut self, buf: &mut [u8]) -> io::Result { - if buf.len() > ::std::i32::MAX as usize { - return Err(io::Error::new(io::ErrorKind::InvalidInput, - Error { - code: ffi::SQLITE_TOOBIG, - message: "buffer too long".to_string(), - })); - } - let mut n = buf.len() as i32; - let size = self.size(); - if self.pos + n > size { - n = size - self.pos; - } + let max_allowed_len = (self.size() - self.pos) as usize; + let n = min(buf.len(), max_allowed_len) as i32; if n <= 0 { return Ok(0); } @@ -129,32 +120,15 @@ impl<'conn> io::Read for Blob<'conn> { impl<'conn> io::Write for Blob<'conn> { /// Write data into a BLOB incrementally /// - /// This function may only modify the contents of the BLOB; it is not possible to increase the size of a BLOB using this API. + /// This function may only modify the contents of the BLOB; it is not possible to increase + /// the size of a BLOB using this API. /// /// # Failure /// - /// Will return `Err` if `buf` length > i32 max value or if `buf` length + offset > BLOB size - /// or if the underlying SQLite write call fails. + /// Will return `Err` if the underlying SQLite write call fails. fn write(&mut self, buf: &[u8]) -> io::Result { - if buf.len() > ::std::i32::MAX as usize { - return Err(io::Error::new(io::ErrorKind::InvalidInput, - Error { - code: ffi::SQLITE_TOOBIG, - message: "buffer too long".to_string(), - })); - } - let n = buf.len() as i32; - let size = self.size(); - if self.pos + n > size { - return Err(io::Error::new(io::ErrorKind::Other, - Error { - code: ffi::SQLITE_MISUSE, - message: format!("pos = {} + n = {} > size = {}", - self.pos, - n, - size), - })); - } + let max_allowed_len = (self.size() - self.pos) as usize; + let n = min(buf.len(), max_allowed_len) as i32; if n <= 0 { return Ok(0); } @@ -206,7 +180,7 @@ impl<'conn> Drop for Blob<'conn> { #[cfg(test)] mod test { - use std::io::{BufReader, BufRead, Read, Write, Seek, SeekFrom}; + use std::io::{BufReader, BufRead, BufWriter, Read, Write, Seek, SeekFrom}; use {Connection, DatabaseName, Result}; #[cfg_attr(rustfmt, rustfmt_skip)] @@ -227,8 +201,8 @@ mod test { let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false).unwrap(); assert_eq!(4, blob.write(b"Clob").unwrap()); - assert!(blob.write(b"5678901").is_err()); // cannot write past 10 - assert_eq!(4, blob.write(b"5678").unwrap()); + assert_eq!(6, blob.write(b"567890xxxxxx").unwrap()); // cannot write past 10 + assert_eq!(0, blob.write(b"5678").unwrap()); // still cannot write past 10 blob.reopen(rowid).unwrap(); blob.close().unwrap(); @@ -238,20 +212,21 @@ mod test { assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); assert_eq!(&bytes, b"Clob5"); assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); - assert_eq!(&bytes, b"678\0\0"); + assert_eq!(&bytes, b"67890"); assert_eq!(0, blob.read(&mut bytes[..]).unwrap()); blob.seek(SeekFrom::Start(2)).unwrap(); assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); assert_eq!(&bytes, b"ob567"); - blob.seek(SeekFrom::Current(-6)).unwrap(); - assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); - assert_eq!(&bytes, b"lob56"); + // only first 4 bytes of `bytes` should be read into + blob.seek(SeekFrom::Current(-1)).unwrap(); + assert_eq!(4, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(&bytes, b"78907"); blob.seek(SeekFrom::End(-6)).unwrap(); assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); - assert_eq!(&bytes, b"5678\0"); + assert_eq!(&bytes, b"56789"); blob.reopen(rowid).unwrap(); assert_eq!(5, blob.read(&mut bytes[..]).unwrap()); @@ -280,4 +255,44 @@ mod test { assert_eq!(2, reader.read_line(&mut line).unwrap()); assert_eq!("\0\0", line); } + + #[test] + fn test_blob_in_bufwriter() { + let (db, rowid) = db_with_test_blob().unwrap(); + + { + let blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false).unwrap(); + let mut writer = BufWriter::new(blob); + + // trying to write too much and then flush should fail + assert_eq!(8, writer.write(b"01234567").unwrap()); + assert_eq!(8, writer.write(b"01234567").unwrap()); + assert!(writer.flush().is_err()); + } + + { + // ... but it should've written the first 10 bytes + let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false).unwrap(); + let mut bytes = [0u8; 10]; + assert_eq!(10, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(b"0123456701", &bytes); + } + + { + let blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false).unwrap(); + let mut writer = BufWriter::new(blob); + + // trying to write_all too much should fail + writer.write_all(b"aaaaaaaaaabbbbb").unwrap(); + assert!(writer.flush().is_err()); + } + + { + // ... but it should've written the first 10 bytes + let mut blob = db.blob_open(DatabaseName::Main, "test", "content", rowid, false).unwrap(); + let mut bytes = [0u8; 10]; + assert_eq!(10, blob.read(&mut bytes[..]).unwrap()); + assert_eq!(b"aaaaaaaaaa", &bytes); + } + } }