2023-02-13 14:36:50 +11:00
|
|
|
use crate::cacheable::Cacheable;
|
2023-02-15 19:06:34 +11:00
|
|
|
use crate::key::CacheKey;
|
2023-02-13 14:36:50 +11:00
|
|
|
use platform_dirs::AppDirs;
|
|
|
|
use rusqlite::{params, Connection, DatabaseName};
|
|
|
|
use std::error::Error;
|
|
|
|
use std::path::PathBuf;
|
|
|
|
|
2023-02-16 09:40:24 +11:00
|
|
|
pub(crate) fn get_cache_dir() -> Result<PathBuf, Box<dyn Error>> {
|
2023-02-13 14:36:50 +11:00
|
|
|
let cache_dir =
|
|
|
|
if let Some(cache_dir) = AppDirs::new(Some("librashader"), false).map(|a| a.cache_dir) {
|
|
|
|
cache_dir
|
|
|
|
} else {
|
|
|
|
let mut current_dir = std::env::current_dir()?;
|
|
|
|
current_dir.push("librashader");
|
|
|
|
current_dir
|
|
|
|
};
|
|
|
|
|
|
|
|
std::fs::create_dir_all(&cache_dir)?;
|
|
|
|
|
|
|
|
Ok(cache_dir)
|
|
|
|
}
|
|
|
|
|
2023-02-16 09:40:24 +11:00
|
|
|
pub(crate) fn get_cache() -> Result<Connection, Box<dyn Error>> {
|
2023-02-13 14:36:50 +11:00
|
|
|
let cache_dir = get_cache_dir()?;
|
|
|
|
let mut conn = Connection::open(&cache_dir.join("librashader.db"))?;
|
|
|
|
|
|
|
|
let tx = conn.transaction()?;
|
|
|
|
tx.pragma_update(Some(DatabaseName::Main), "journal_mode", "wal2")?;
|
|
|
|
tx.execute(
|
|
|
|
r#"create table if not exists cache (
|
|
|
|
type text not null,
|
|
|
|
id blob not null,
|
|
|
|
value blob not null unique,
|
|
|
|
primary key (id, type)
|
|
|
|
)"#,
|
|
|
|
[],
|
|
|
|
)?;
|
|
|
|
tx.commit()?;
|
|
|
|
Ok(conn)
|
|
|
|
}
|
|
|
|
|
|
|
|
pub(crate) fn get_blob(
|
|
|
|
conn: &Connection,
|
|
|
|
index: &str,
|
|
|
|
key: &[u8],
|
|
|
|
) -> Result<Vec<u8>, Box<dyn Error>> {
|
|
|
|
let value = conn.query_row(
|
|
|
|
&*format!("select value from cache where (type = (?1) and id = (?2))"),
|
|
|
|
params![index, key],
|
|
|
|
|row| row.get(0),
|
|
|
|
)?;
|
|
|
|
Ok(value)
|
|
|
|
}
|
|
|
|
|
|
|
|
pub(crate) fn set_blob(conn: &Connection, index: &str, key: &[u8], value: &[u8]) {
|
|
|
|
match conn.execute(
|
|
|
|
&*format!("insert or replace into cache (type, id, value) values (?1, ?2, ?3)"),
|
|
|
|
params![index, key, value],
|
|
|
|
) {
|
|
|
|
Ok(_) => return,
|
|
|
|
Err(e) => println!("err: {:?}", e),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-02-16 09:40:24 +11:00
|
|
|
/// Cache a shader object (usually bytecode) created by the keyed objects.
|
|
|
|
///
|
|
|
|
/// - `factory` is the function that compiles the values passed as keys to a shader object.
|
|
|
|
/// - `load` tries to load a compiled shader object to a driver-specialized result.
|
|
|
|
pub fn cache_shader_object<E, T, R, H, const KEY_SIZE: usize>(
|
2023-02-13 14:36:50 +11:00
|
|
|
index: &str,
|
|
|
|
keys: &[H; KEY_SIZE],
|
|
|
|
factory: impl FnOnce(&[H; KEY_SIZE]) -> Result<T, E>,
|
2023-02-16 09:40:24 +11:00
|
|
|
load: impl Fn(T) -> Result<R, E>,
|
|
|
|
bypass_cache: bool,
|
2023-02-13 14:36:50 +11:00
|
|
|
) -> Result<R, E>
|
|
|
|
where
|
|
|
|
H: CacheKey,
|
|
|
|
T: Cacheable,
|
|
|
|
{
|
2023-02-16 09:40:24 +11:00
|
|
|
if bypass_cache {
|
|
|
|
return Ok(load(factory(keys)?)?);
|
2023-02-13 14:36:50 +11:00
|
|
|
}
|
|
|
|
|
|
|
|
let cache = get_cache();
|
|
|
|
|
|
|
|
let Ok(cache) = cache else {
|
2023-02-16 09:40:24 +11:00
|
|
|
return Ok(load(factory(keys)?)?);
|
2023-02-13 14:36:50 +11:00
|
|
|
};
|
|
|
|
|
|
|
|
let hashkey = {
|
|
|
|
let mut hasher = blake3::Hasher::new();
|
|
|
|
for subkeys in keys {
|
|
|
|
hasher.update(subkeys.hash_bytes());
|
|
|
|
}
|
|
|
|
let hash = hasher.finalize();
|
|
|
|
hash
|
|
|
|
};
|
|
|
|
|
|
|
|
'attempt: {
|
|
|
|
if let Ok(blob) = get_blob(&cache, index, hashkey.as_bytes()) {
|
2023-02-16 09:40:24 +11:00
|
|
|
let cached = T::from_bytes(&blob).map(&load);
|
2023-02-13 14:36:50 +11:00
|
|
|
|
|
|
|
match cached {
|
|
|
|
None => break 'attempt,
|
|
|
|
Some(Err(_)) => break 'attempt,
|
|
|
|
Some(Ok(res)) => return Ok(res),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
|
|
|
|
let blob = factory(keys)?;
|
|
|
|
|
|
|
|
if let Some(slice) = T::to_bytes(&blob) {
|
|
|
|
set_blob(&cache, index, hashkey.as_bytes(), &slice);
|
|
|
|
}
|
2023-02-16 09:40:24 +11:00
|
|
|
Ok(load(blob)?)
|
2023-02-13 14:36:50 +11:00
|
|
|
}
|
|
|
|
|
2023-02-16 09:40:24 +11:00
|
|
|
/// Cache a pipeline state object.
|
|
|
|
///
|
|
|
|
/// Keys are not used to create the object and are only used to uniquely identify the pipeline state.
|
|
|
|
///
|
|
|
|
/// - `restore_pipeline` tries to restore the pipeline with either a cached binary pipeline state
|
|
|
|
/// cache, or create a new pipeline if no cached value is available.
|
|
|
|
/// - `fetch_pipeline_state` fetches the new pipeline state cache after the pipeline was created.
|
2023-02-13 14:36:50 +11:00
|
|
|
pub fn cache_pipeline<E, T, R, const KEY_SIZE: usize>(
|
|
|
|
index: &str,
|
|
|
|
keys: &[&dyn CacheKey; KEY_SIZE],
|
2023-02-16 09:40:24 +11:00
|
|
|
restore_pipeline: impl Fn(Option<Vec<u8>>) -> Result<R, E>,
|
|
|
|
fetch_pipeline_state: impl FnOnce(&R) -> Result<T, E>,
|
|
|
|
bypass_cache: bool,
|
2023-02-13 14:36:50 +11:00
|
|
|
) -> Result<R, E>
|
|
|
|
where
|
|
|
|
T: Cacheable,
|
|
|
|
{
|
2023-02-16 09:40:24 +11:00
|
|
|
if bypass_cache {
|
|
|
|
return Ok(restore_pipeline(None)?);
|
2023-02-13 14:36:50 +11:00
|
|
|
}
|
|
|
|
|
|
|
|
let cache = get_cache();
|
|
|
|
|
|
|
|
let Ok(cache) = cache else {
|
2023-02-16 09:40:24 +11:00
|
|
|
return Ok(restore_pipeline(None)?);
|
2023-02-13 14:36:50 +11:00
|
|
|
};
|
|
|
|
|
|
|
|
let hashkey = {
|
|
|
|
let mut hasher = blake3::Hasher::new();
|
|
|
|
for subkeys in keys {
|
|
|
|
hasher.update(subkeys.hash_bytes());
|
|
|
|
}
|
|
|
|
let hash = hasher.finalize();
|
|
|
|
hash
|
|
|
|
};
|
|
|
|
|
|
|
|
let pipeline = 'attempt: {
|
|
|
|
if let Ok(blob) = get_blob(&cache, index, hashkey.as_bytes()) {
|
2023-02-16 09:40:24 +11:00
|
|
|
let cached = restore_pipeline(Some(blob));
|
2023-02-13 14:36:50 +11:00
|
|
|
match cached {
|
|
|
|
Ok(res) => {
|
|
|
|
break 'attempt res;
|
|
|
|
}
|
|
|
|
_ => (),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-02-16 09:40:24 +11:00
|
|
|
restore_pipeline(None)?
|
2023-02-13 14:36:50 +11:00
|
|
|
};
|
|
|
|
|
|
|
|
// update the pso every time just in case.
|
2023-02-16 09:40:24 +11:00
|
|
|
if let Ok(state) = fetch_pipeline_state(&pipeline) {
|
2023-02-13 14:36:50 +11:00
|
|
|
if let Some(slice) = T::to_bytes(&state) {
|
|
|
|
set_blob(&cache, index, hashkey.as_bytes(), &slice);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
Ok(pipeline)
|
|
|
|
}
|