Skip to content

Commit

Permalink
Merge pull request #45 from ShoyuVanilla/master
Browse files Browse the repository at this point in the history
Fix #39
  • Loading branch information
jerebtw committed May 15, 2023
2 parents bcc7ca8 + 2671983 commit e686864
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 4 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ path = "tests/test.rs"

[dependencies]
r2d2 = "0.8"
uuid = { version = "1.0", features = ["v4", "fast-rng"] }

[dependencies.rusqlite]
version = "0.29"
Expand Down
29 changes: 25 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,17 @@
//! .unwrap()
//! }
//! ```
pub use rusqlite;
use rusqlite::{Connection, Error, OpenFlags};
use std::fmt;
use std::path::{Path, PathBuf};
pub use rusqlite;
use std::sync::Mutex;
use uuid::Uuid;

#[derive(Debug)]
enum Source {
File(PathBuf),
Memory,
Memory(String),
}

type InitFn = dyn Fn(&mut Connection) -> Result<(), rusqlite::Error> + Send + Sync + 'static;
Expand All @@ -59,6 +61,7 @@ pub struct SqliteConnectionManager {
source: Source,
flags: OpenFlags,
init: Option<Box<InitFn>>,
_persist: Mutex<Option<Connection>>,
}

impl fmt::Debug for SqliteConnectionManager {
Expand All @@ -80,15 +83,17 @@ impl SqliteConnectionManager {
source: Source::File(path.as_ref().to_path_buf()),
flags: OpenFlags::default(),
init: None,
_persist: Mutex::new(None),
}
}

/// Creates a new `SqliteConnectionManager` from memory.
pub fn memory() -> Self {
Self {
source: Source::Memory,
source: Source::Memory(Uuid::new_v4().to_string()),
flags: OpenFlags::default(),
init: None,
_persist: Mutex::new(None),
}
}

Expand Down Expand Up @@ -130,7 +135,23 @@ impl r2d2::ManageConnection for SqliteConnectionManager {
fn connect(&self) -> Result<Connection, Error> {
match self.source {
Source::File(ref path) => Connection::open_with_flags(path, self.flags),
Source::Memory => Connection::open_in_memory_with_flags(self.flags),
Source::Memory(ref id) => {
let connection = || {
Connection::open_with_flags(
format!("file:{}?mode=memory&cache=shared", id),
self.flags,
)
};

{
let mut persist = self._persist.lock().unwrap();
if persist.is_none() {
*persist = Some(connection()?);
}
}

connection()
}
}
.map_err(Into::into)
.and_then(|mut c| match self.init {
Expand Down
74 changes: 74 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,77 @@ fn test_with_init() {
.unwrap();
assert_eq!(db_version, 123);
}

#[test]
fn test_in_memory_db_is_shared() {
let manager = SqliteConnectionManager::memory();
let pool = r2d2::Pool::builder().max_size(10).build(manager).unwrap();

pool.get()
.unwrap()
.execute("CREATE TABLE IF NOT EXISTS foo (bar INTEGER)", [])
.unwrap();

(0..10)
.map(|i: i32| {
let pool = pool.clone();
std::thread::spawn(move || {
let conn = pool.get().unwrap();
conn.execute("INSERT INTO foo (bar) VALUES (?)", [i])
.unwrap();
})
})
.collect::<Vec<_>>()
.into_iter()
.try_for_each(std::thread::JoinHandle::join)
.unwrap();

let conn = pool.get().unwrap();
let mut stmt = conn.prepare("SELECT bar from foo").unwrap();
let mut rows: Vec<i32> = stmt
.query_map([], |row| row.get(0))
.unwrap()
.into_iter()
.flatten()
.collect();
rows.sort_unstable();
assert_eq!(rows, (0..10).collect::<Vec<_>>());
}

#[test]
fn test_different_in_memory_dbs_are_not_shared() {
let manager1 = SqliteConnectionManager::memory();
let pool1 = r2d2::Pool::new(manager1).unwrap();
let manager2 = SqliteConnectionManager::memory();
let pool2 = r2d2::Pool::new(manager2).unwrap();

pool1
.get()
.unwrap()
.execute_batch("CREATE TABLE foo (bar INTEGER)")
.unwrap();
let result = pool2
.get()
.unwrap()
.execute_batch("CREATE TABLE foo (bar INTEGER)");

assert!(result.is_ok());
}

#[test]
fn test_in_memory_db_persists() {
let manager = SqliteConnectionManager::memory();

{
// Normally, `r2d2::Pool` won't drop connection unless timed-out or broken.
// So let's drop managed connection instead.
let conn = manager.connect().unwrap();
conn.execute_batch("CREATE TABLE foo (bar INTEGER)")
.unwrap();
}

let conn = manager.connect().unwrap();
let mut stmt = conn.prepare("SELECT * from foo").unwrap();
let result = stmt.execute([]);
assert!(result.is_ok());
}

0 comments on commit e686864

Please sign in to comment.