use super::models::DatabaseError;
use ariadne::ids::base62_impl::{parse_base62, to_base62};
use chrono::{TimeZone, Utc};
use dashmap::DashMap;
use deadpool_redis::{Config, Runtime};
use futures::future::Either;
use prometheus::{IntGauge, Registry};
use redis::{Cmd, ExistenceCheck, SetExpiry, SetOptions, cmd};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::{Debug, Display};
use std::future::Future;
use std::hash::Hash;
use std::time::Duration;

const DEFAULT_EXPIRY: i64 = 60 * 60 * 12; // 12 hours
const ACTUAL_EXPIRY: i64 = 60 * 30; // 30 minutes

#[derive(Clone)]
pub struct RedisPool {
    pub url: String,
    pub pool: deadpool_redis::Pool,
    meta_namespace: String,
}

pub struct RedisConnection {
    pub connection: deadpool_redis::Connection,
    meta_namespace: String,
}

impl RedisPool {
    // Initiate a new Redis pool.
    // The testing pool uses a hashmap to mimic Redis behaviour for very small
    // data sizes (i.e. tests).
    // PANICS: the production pool will panic if the Redis URL is not set.
    pub fn new(meta_namespace: Option<String>) -> Self {
        let url = dotenvy::var("REDIS_URL").expect("Redis URL not set");
        let pool = Config::from_url(url.clone())
            .builder()
            .expect("Error building Redis pool")
            .max_size(
                dotenvy::var("REDIS_MAX_CONNECTIONS")
                    .ok()
                    .and_then(|x| x.parse().ok())
                    .unwrap_or(10000),
            )
            .runtime(Runtime::Tokio1)
            .build()
            .expect("Redis connection failed");

        RedisPool {
            url,
            pool,
            meta_namespace: meta_namespace.unwrap_or("".to_string()),
        }
    }

    pub async fn register_and_set_metrics(
        &self,
        registry: &Registry,
    ) -> Result<(), prometheus::Error> {
        let redis_max_size = IntGauge::new(
            "labrinth_redis_pool_max_size",
            "Maximum size of Redis pool",
        )?;
        let redis_size = IntGauge::new(
            "labrinth_redis_pool_size",
            "Current size of Redis pool",
        )?;
        let redis_available = IntGauge::new(
            "labrinth_redis_pool_available",
            "Available connections in Redis pool",
        )?;
        let redis_waiting = IntGauge::new(
            "labrinth_redis_pool_waiting",
            "Number of futures waiting for a Redis connection",
        )?;

        registry.register(Box::new(redis_max_size.clone()))?;
        registry.register(Box::new(redis_size.clone()))?;
        registry.register(Box::new(redis_available.clone()))?;
        registry.register(Box::new(redis_waiting.clone()))?;

        let redis_pool_ref = self.pool.clone();
        tokio::spawn(async move {
            loop {
                let status = redis_pool_ref.status();
                redis_max_size.set(status.max_size as i64);
                redis_size.set(status.size as i64);
                redis_available.set(status.available as i64);
                redis_waiting.set(status.waiting as i64);
                tokio::time::sleep(Duration::from_secs(5)).await;
            }
        });

        Ok(())
    }

    pub async fn connect(&self) -> Result<RedisConnection, DatabaseError> {
        Ok(RedisConnection {
            connection: self.pool.get().await?,
            meta_namespace: self.meta_namespace.clone(),
        })
    }
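    /// Fetches the values for `keys` from the cache under `namespace`,
    /// calling `closure` only for the keys that are missing or expired and
    /// writing those fresh results back to Redis.
    ///
    /// A minimal usage sketch; the `User` type and the `fetch_users_from_db`
    /// helper are hypothetical, shown for illustration only:
    ///
    /// ```ignore
    /// let users: Vec<User> = redis
    ///     .get_cached_keys("users", &ids, |missing| async move {
    ///         // must return Result<DashMap<K, User>, DatabaseError>
    ///         fetch_users_from_db(missing).await
    ///     })
    ///     .await?;
    /// ```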
    pub async fn get_cached_keys<F, Fut, T, K>(
        &self,
        namespace: &str,
        keys: &[K],
        closure: F,
    ) -> Result<Vec<T>, DatabaseError>
    where
        F: FnOnce(Vec<K>) -> Fut,
        Fut: Future<Output = Result<DashMap<K, T>, DatabaseError>>,
        T: Serialize + DeserializeOwned,
        K: Display
            + Hash
            + Eq
            + PartialEq
            + Clone
            + DeserializeOwned
            + Serialize
            + Debug,
    {
        Ok(self
            .get_cached_keys_raw(namespace, keys, closure)
            .await?
            .into_iter()
            .map(|x| x.1)
            .collect())
    }

    pub async fn get_cached_keys_raw<F, Fut, T, K>(
        &self,
        namespace: &str,
        keys: &[K],
        closure: F,
    ) -> Result<HashMap<K, T>, DatabaseError>
    where
        F: FnOnce(Vec<K>) -> Fut,
        Fut: Future<Output = Result<DashMap<K, T>, DatabaseError>>,
        T: Serialize + DeserializeOwned,
        K: Display
            + Hash
            + Eq
            + PartialEq
            + Clone
            + DeserializeOwned
            + Serialize
            + Debug,
    {
        self.get_cached_keys_raw_with_slug(
            namespace,
            None,
            false,
            keys,
            |ids| async move {
                Ok(closure(ids)
                    .await?
                    .into_iter()
                    .map(|(key, val)| (key, (None::<String>, val)))
                    .collect())
            },
        )
        .await
    }

    pub async fn get_cached_keys_with_slug<F, Fut, T, I, K, S>(
        &self,
        namespace: &str,
        slug_namespace: &str,
        case_sensitive: bool,
        keys: &[I],
        closure: F,
    ) -> Result<Vec<T>, DatabaseError>
    where
        F: FnOnce(Vec<I>) -> Fut,
        Fut: Future<Output = Result<DashMap<K, (Option<S>, T)>, DatabaseError>>,
        T: Serialize + DeserializeOwned,
        I: Display + Hash + Eq + PartialEq + Clone + Debug,
        K: Display + Hash + Eq + PartialEq + Clone + DeserializeOwned + Serialize,
        S: Display + Clone + DeserializeOwned + Serialize + Debug,
    {
        Ok(self
            .get_cached_keys_raw_with_slug(
                namespace,
                Some(slug_namespace),
                case_sensitive,
                keys,
                closure,
            )
            .await?
            .into_iter()
            .map(|x| x.1)
            .collect())
    }

    pub async fn get_cached_keys_raw_with_slug<F, Fut, T, I, K, S>(
        &self,
        namespace: &str,
        slug_namespace: Option<&str>,
        case_sensitive: bool,
        keys: &[I],
        closure: F,
    ) -> Result<HashMap<K, T>, DatabaseError>
    where
        F: FnOnce(Vec<I>) -> Fut,
        Fut: Future<Output = Result<DashMap<K, (Option<S>, T)>, DatabaseError>>,
        T: Serialize + DeserializeOwned,
        I: Display + Hash + Eq + PartialEq + Clone + Debug,
        K: Display + Hash + Eq + PartialEq + Clone + DeserializeOwned + Serialize,
        S: Display + Clone + DeserializeOwned + Serialize + Debug,
    {
        let ids = keys
            .iter()
            .map(|x| (x.to_string(), x.clone()))
            .collect::<DashMap<String, I>>();

        if ids.is_empty() {
            return Ok(HashMap::new());
        }

        let get_cached_values = |ids: DashMap<String, I>| async move {
            let slug_ids = if let Some(slug_namespace) = slug_namespace {
                let mut connection = self.pool.get().await?;
                cmd("MGET")
                    .arg(
                        ids.iter()
                            .map(|x| {
                                format!(
                                    "{}_{slug_namespace}:{}",
                                    self.meta_namespace,
                                    if case_sensitive {
                                        x.value().to_string()
                                    } else {
                                        x.value().to_string().to_lowercase()
                                    }
                                )
                            })
                            .collect::<Vec<_>>(),
                    )
                    .query_async::<Vec<Option<String>>>(&mut connection)
                    .await?
                    .into_iter()
                    .flatten()
                    .collect::<Vec<_>>()
            } else {
                Vec::new()
            };

            let mut connection = self.pool.get().await?;

            let cached_values = cmd("MGET")
                .arg(
                    ids.iter()
                        .map(|x| x.value().to_string())
                        .chain(ids.iter().filter_map(|x| {
                            parse_base62(&x.value().to_string())
                                .ok()
                                .map(|x| x.to_string())
                        }))
                        .chain(slug_ids)
                        .map(|x| {
                            format!("{}_{namespace}:{x}", self.meta_namespace)
                        })
                        .collect::<Vec<_>>(),
                )
                .query_async::<Vec<Option<String>>>(&mut connection)
                .await?
                .into_iter()
                .filter_map(|x| {
                    x.and_then(|val| {
                        serde_json::from_str::<RedisValue<T, K, S>>(&val).ok()
                    })
                    .map(|val| (val.key.clone(), val))
                })
                .collect::<HashMap<_, _>>();

            Ok::<_, DatabaseError>((cached_values, ids))
        };

        let current_time = Utc::now();
        let mut expired_values = HashMap::new();

        let (cached_values_raw, ids) = get_cached_values(ids).await?;
        let mut cached_values = cached_values_raw
            .into_iter()
            .filter_map(|(key, val)| {
                if Utc.timestamp_opt(val.iat + ACTUAL_EXPIRY, 0).unwrap()
                    < current_time
                {
                    expired_values.insert(val.key.to_string(), val);
                    None
                } else {
                    let key_str = val.key.to_string();
                    ids.remove(&key_str);

                    if let Ok(value) = key_str.parse::<u64>() {
                        let base62 = to_base62(value);
                        ids.remove(&base62);
                    }

                    if let Some(ref alias) = val.alias {
                        ids.remove(&alias.to_string());
                    }

                    Some((key, val))
                }
            })
            .collect::<HashMap<_, _>>();

        let subscribe_ids = DashMap::new();

        if !ids.is_empty() {
            let mut pipe = redis::pipe();
            let fetch_ids =
                ids.iter().map(|x| x.key().clone()).collect::<Vec<_>>();
            fetch_ids.iter().for_each(|key| {
                pipe.atomic().set_options(
                    // We store locks in lowercase because they are case insensitive
                    format!(
                        "{}_{namespace}:{}/lock",
                        self.meta_namespace,
                        key.to_lowercase()
                    ),
                    100,
                    SetOptions::default()
                        .get(true)
                        .conditional_set(ExistenceCheck::NX)
                        .with_expiration(SetExpiry::EX(60)),
                );
            });
            let results = {
                let mut connection = self.pool.get().await?;
                pipe.query_async::<Vec<Option<i64>>>(&mut connection)
                    .await?
            };
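            // Each `SET ... NX GET` above returns the lock's previous value:
            // `None` means this worker acquired the lock, so the key stays in
            // `ids` and is fetched by our own closure below. `Some(_)` means
            // another worker is already fetching it, so we either serve the
            // stale value we still hold or subscribe and wait for that
            // worker's result.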
            for (idx, key) in fetch_ids.into_iter().enumerate() {
                if let Some(locked) = results.get(idx)
                    && locked.is_none()
                {
                    continue;
                }

                if let Some((key, raw_key)) = ids.remove(&key) {
                    if let Some(val) = expired_values.remove(&key) {
                        if let Some(ref alias) = val.alias {
                            ids.remove(&alias.to_string());
                        }

                        if let Ok(value) = val.key.to_string().parse::<u64>() {
                            let base62 = to_base62(value);
                            ids.remove(&base62);
                        }

                        cached_values.insert(val.key.clone(), val);
                    } else {
                        subscribe_ids.insert(key, raw_key);
                    }
                }
            }
        }

        let mut fetch_tasks = Vec::new();

        if !ids.is_empty() {
            fetch_tasks.push(Either::Left(async {
                let fetch_ids =
                    ids.iter().map(|x| x.value().clone()).collect::<Vec<_>>();

                let vals = closure(fetch_ids).await?;
                let mut return_values = HashMap::new();

                let mut pipe = redis::pipe();
                if !vals.is_empty() {
                    for (key, (slug, value)) in vals {
                        let value = RedisValue {
                            key: key.clone(),
                            iat: Utc::now().timestamp(),
                            val: value,
                            alias: slug.clone(),
                        };

                        pipe.atomic().set_ex(
                            format!(
                                "{}_{namespace}:{key}",
                                self.meta_namespace
                            ),
                            serde_json::to_string(&value)?,
                            DEFAULT_EXPIRY as u64,
                        );

                        if let Some(slug) = slug {
                            ids.remove(&slug.to_string());

                            if let Some(slug_namespace) = slug_namespace {
                                let actual_slug = if case_sensitive {
                                    slug.to_string()
                                } else {
                                    slug.to_string().to_lowercase()
                                };

                                pipe.atomic().set_ex(
                                    format!(
                                        "{}_{slug_namespace}:{}",
                                        self.meta_namespace, actual_slug
                                    ),
                                    key.to_string(),
                                    DEFAULT_EXPIRY as u64,
                                );

                                pipe.atomic().del(format!(
                                    "{}_{namespace}:{}/lock",
                                    // Locks are stored in lowercase
                                    self.meta_namespace,
                                    actual_slug.to_lowercase()
                                ));
                            }
                        }

                        let key_str = key.to_string();
                        ids.remove(&key_str);

                        if let Ok(value) = key_str.parse::<u64>() {
                            let base62 = to_base62(value);
                            ids.remove(&base62);
                            pipe.atomic().del(format!(
                                "{}_{namespace}:{}/lock",
                                self.meta_namespace,
                                // Locks are stored in lowercase
                                base62.to_lowercase()
                            ));
                        }

                        pipe.atomic().del(format!(
                            "{}_{namespace}:{key}/lock",
                            self.meta_namespace
                        ));

                        return_values.insert(key, value);
                    }
                }

                for (key, _) in ids {
                    pipe.atomic().del(format!(
                        "{}_{namespace}:{}/lock",
                        self.meta_namespace,
                        key.to_lowercase()
                    ));
                    pipe.atomic().del(format!(
                        "{}_{namespace}:{key}/lock",
                        self.meta_namespace
                    ));
                }

                let mut connection = self.pool.get().await?;
                pipe.query_async::<()>(&mut connection).await?;

                Ok(return_values)
            }));
        }

        if !subscribe_ids.is_empty() {
            fetch_tasks.push(Either::Right(async {
                let mut interval =
                    tokio::time::interval(Duration::from_millis(100));
                let start = Utc::now();
                loop {
                    let results = {
                        let mut connection = self.pool.get().await?;
                        cmd("MGET")
                            .arg(
                                subscribe_ids
                                    .iter()
                                    .map(|x| {
                                        format!(
                                            "{}_{namespace}:{}/lock",
                                            self.meta_namespace,
                                            // We lowercase the key because locks are stored in lowercase
                                            x.key().to_lowercase()
                                        )
                                    })
                                    .collect::<Vec<_>>(),
                            )
                            .query_async::<Vec<Option<String>>>(
                                &mut connection,
                            )
                            .await?
                    };

                    if results.into_iter().all(|x| x.is_none()) {
                        break;
                    }

                    if (Utc::now() - start) > chrono::Duration::seconds(5) {
                        return Err(DatabaseError::CacheTimeout);
                    }

                    interval.tick().await;
                }

                let (return_values, _) =
                    get_cached_values(subscribe_ids).await?;

                Ok(return_values)
            }));
        }
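        // The two branches above build futures of different concrete types;
        // `Either::Left` / `Either::Right` wrap them in one enum type so both
        // can share a single `Vec` and be driven together by `try_join_all`.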
        if !fetch_tasks.is_empty() {
            for map in futures::future::try_join_all(fetch_tasks).await? {
                for (key, value) in map {
                    cached_values.insert(key, value);
                }
            }
        }

        Ok(cached_values.into_iter().map(|x| (x.0, x.1.val)).collect())
    }
}

impl RedisConnection {
    pub async fn set(
        &mut self,
        namespace: &str,
        id: &str,
        data: &str,
        expiry: Option<i64>,
    ) -> Result<(), DatabaseError> {
        let mut cmd = cmd("SET");
        redis_args(
            &mut cmd,
            vec![
                format!("{}_{}:{}", self.meta_namespace, namespace, id),
                data.to_string(),
                "EX".to_string(),
                expiry.unwrap_or(DEFAULT_EXPIRY).to_string(),
            ]
            .as_slice(),
        );
        redis_execute::<()>(&mut cmd, &mut self.connection).await?;
        Ok(())
    }

    pub async fn set_serialized_to_json<Id, D>(
        &mut self,
        namespace: &str,
        id: Id,
        data: D,
        expiry: Option<i64>,
    ) -> Result<(), DatabaseError>
    where
        Id: Display,
        D: serde::Serialize,
    {
        self.set(
            namespace,
            &id.to_string(),
            &serde_json::to_string(&data)?,
            expiry,
        )
        .await
    }

    pub async fn get(
        &mut self,
        namespace: &str,
        id: &str,
    ) -> Result<Option<String>, DatabaseError> {
        let mut cmd = cmd("GET");
        redis_args(
            &mut cmd,
            vec![format!("{}_{}:{}", self.meta_namespace, namespace, id)]
                .as_slice(),
        );
        let res = redis_execute(&mut cmd, &mut self.connection).await?;
        Ok(res)
    }

    pub async fn get_many(
        &mut self,
        namespace: &str,
        ids: &[String],
    ) -> Result<Vec<Option<String>>, DatabaseError> {
        let mut cmd = cmd("MGET");
        redis_args(
            &mut cmd,
            ids.iter()
                .map(|x| {
                    format!("{}_{}:{}", self.meta_namespace, namespace, x)
                })
                .collect::<Vec<_>>()
                .as_slice(),
        );
        let res = redis_execute(&mut cmd, &mut self.connection).await?;
        Ok(res)
    }

    pub async fn get_deserialized_from_json<R>(
        &mut self,
        namespace: &str,
        id: &str,
    ) -> Result<Option<R>, DatabaseError>
    where
        R: for<'a> serde::Deserialize<'a>,
    {
        Ok(self
            .get(namespace, id)
            .await?
            .and_then(|x| serde_json::from_str(&x).ok()))
    }

    pub async fn get_many_deserialized_from_json<R>(
        &mut self,
        namespace: &str,
        ids: &[String],
    ) -> Result<Vec<Option<R>>, DatabaseError>
    where
        R: for<'a> serde::Deserialize<'a>,
    {
        Ok(self
            .get_many(namespace, ids)
            .await?
            .into_iter()
            .map(|x| {
                x.and_then(|val| serde_json::from_str::<R>(&val).ok())
            })
            .collect::<Vec<_>>())
    }

    pub async fn delete<T1>(
        &mut self,
        namespace: &str,
        id: T1,
    ) -> Result<(), DatabaseError>
    where
        T1: Display,
    {
        let mut cmd = cmd("DEL");
        redis_args(
            &mut cmd,
            vec![format!("{}_{}:{}", self.meta_namespace, namespace, id)]
                .as_slice(),
        );
        redis_execute::<()>(&mut cmd, &mut self.connection).await?;
        Ok(())
    }

    pub async fn delete_many(
        &mut self,
        iter: impl IntoIterator<Item = (&str, Option<String>)>,
    ) -> Result<(), DatabaseError> {
        let mut cmd = cmd("DEL");
        let mut any = false;
        for (namespace, id) in iter {
            if let Some(id) = id {
                redis_args(
                    &mut cmd,
                    [format!("{}_{}:{}", self.meta_namespace, namespace, id)]
                        .as_slice(),
                );
                any = true;
            }
        }

        if any {
            redis_execute::<()>(&mut cmd, &mut self.connection).await?;
        }

        Ok(())
    }
}

#[derive(Serialize, Deserialize)]
pub struct RedisValue<T, K, S> {
    key: K,
    #[serde(skip_serializing_if = "Option::is_none")]
    alias: Option<S>,
    iat: i64,
    val: T,
}

pub fn redis_args(cmd: &mut Cmd, args: &[String]) {
    for arg in args {
        cmd.arg(arg);
    }
}

pub async fn redis_execute<T>(
    cmd: &mut Cmd,
    redis: &mut deadpool_redis::Connection,
) -> Result<T, DatabaseError>
where
    T: redis::FromRedisValue,
{
    let res = cmd.query_async::<T>(redis).await?;
    Ok(res)
}
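// A minimal sketch of direct `RedisConnection` usage; the "motd" namespace
// and the values below are hypothetical, shown for illustration only:
//
//     let mut conn = redis.connect().await?;
//     conn.set("motd", "today", "Hello!", None).await?; // default 12 h expiry
//     let motd = conn.get("motd", "today").await?;      // Some("Hello!".into())
//     conn.delete("motd", "today").await?;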