Files
AstralRinth/packages/app-lib/src/util/fetch.rs
François-Xavier Talbot 79217e78b4 Theseus circuit breaker (#5196)
* Impl request fence

* Reduce max FETCH_ATTEMPTS

* Tweak fail threshold & window

* Check block status every retry iteration

* Fix fmt

* Fix fmt 2
2026-01-23 20:28:50 +00:00

579 lines
17 KiB
Rust

//! Functions for fetching information from the Internet
use super::io::{self, IOError};
use crate::ErrorKind;
use crate::event::LoadingBarId;
use crate::event::emit::emit_loading;
use bytes::Bytes;
use chrono::{DateTime, TimeDelta, Utc};
use parking_lot::Mutex;
use rand::Rng;
use reqwest::Method;
use serde::de::DeserializeOwned;
use std::collections::VecDeque;
use std::ffi::OsStr;
use std::path::{Path, PathBuf};
use std::sync::LazyLock;
use std::time::{self};
use tokio::sync::Semaphore;
use tokio::{fs::File, io::AsyncWriteExt};
#[derive(Debug)]
pub struct IoSemaphore(pub Semaphore);
#[derive(Debug)]
pub struct FetchSemaphore(pub Semaphore);
struct FetchFence {
inner: Mutex<FenceInner>,
}
impl FetchFence {
pub fn is_blocked(&self) -> bool {
self.inner.lock().is_blocked()
}
pub fn record_ok(&self) {
self.inner.lock().record_ok()
}
pub fn record_fail(&self) {
self.inner.lock().record_fail()
}
}
struct FenceInner {
failures: VecDeque<DateTime<Utc>>,
block_until: Option<DateTime<Utc>>,
block_factor: i32,
}
impl FenceInner {
const FAILURE_WINDOW: TimeDelta = TimeDelta::minutes(3);
const FAILURE_THRESHOLD: usize = 4;
const BLOCK_DURATION_MIN_BASE: TimeDelta = TimeDelta::minutes(2);
const BLOCK_DURATION_MAX_BASE: TimeDelta = TimeDelta::minutes(5);
const BLOCK_DURATION_MAX_FACTOR: i32 = 3;
pub fn new() -> Self {
Self {
failures: VecDeque::new(),
block_until: None,
block_factor: 0,
}
}
pub fn is_blocked(&mut self) -> bool {
if let Some(until) = self.block_until {
if until > Utc::now() {
return true;
} else {
self.block_until = None;
}
}
false
}
pub fn record_ok(&mut self) {
self.prune(Utc::now());
}
pub fn record_fail(&mut self) {
self.prune(Utc::now());
self.failures.push_back(Utc::now());
if self.failures.len() >= Self::FAILURE_THRESHOLD {
self.trigger_block();
}
}
/// Blocks further requests for a random duration between the min and max base durations, scaled by a factor
/// of how many blocks have been triggered in this session.
///
/// As such, for the first block, the duration will be between 2 and 5 minutes.
/// - For the second block, between 4 and 10 minutes.
/// - For the third block and any further blocks, between 6 and 15 minutes.
fn trigger_block(&mut self) {
self.block_factor =
i32::min(self.block_factor + 1, Self::BLOCK_DURATION_MAX_FACTOR);
let min = Self::BLOCK_DURATION_MIN_BASE
.checked_mul(self.block_factor)
.unwrap_or(Self::BLOCK_DURATION_MIN_BASE);
let max = Self::BLOCK_DURATION_MAX_BASE
.checked_mul(self.block_factor)
.unwrap_or(Self::BLOCK_DURATION_MAX_BASE);
let delta_seconds = (max - min).as_seconds_f64()
* rand::thread_rng().gen_range(0.0..=1.0);
let duration =
min + TimeDelta::milliseconds((delta_seconds * 1000.0) as i64);
self.block_until = Some(Utc::now() + duration);
}
/// Removes all failure points older than the failure window
fn prune(&mut self, now: DateTime<Utc>) {
let cutoff = now - Self::FAILURE_WINDOW;
while let Some(&front) = self.failures.front() {
if front < cutoff {
self.failures.pop_front();
} else {
break;
}
}
}
}
static GLOBAL_FETCH_FENCE: LazyLock<FetchFence> =
LazyLock::new(|| FetchFence {
inner: Mutex::new(FenceInner::new()),
});
pub static REQWEST_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
let mut headers = reqwest::header::HeaderMap::new();
let header =
reqwest::header::HeaderValue::from_str(&crate::launcher_user_agent())
.unwrap();
headers.insert(reqwest::header::USER_AGENT, header);
reqwest::Client::builder()
.tcp_keepalive(Some(time::Duration::from_secs(10)))
.default_headers(headers)
.build()
.expect("Reqwest Client Building Failed")
});
const FETCH_ATTEMPTS: usize = 2;
#[tracing::instrument(skip(semaphore))]
pub async fn fetch(
url: &str,
sha1: Option<&str>,
semaphore: &FetchSemaphore,
exec: impl sqlx::Executor<'_, Database = sqlx::Sqlite>,
) -> crate::Result<Bytes> {
fetch_advanced(Method::GET, url, sha1, None, None, None, semaphore, exec)
.await
}
#[tracing::instrument(skip(json_body, semaphore))]
pub async fn fetch_json<T>(
method: Method,
url: &str,
sha1: Option<&str>,
json_body: Option<serde_json::Value>,
semaphore: &FetchSemaphore,
exec: impl sqlx::Executor<'_, Database = sqlx::Sqlite>,
) -> crate::Result<T>
where
T: DeserializeOwned,
{
let result = fetch_advanced(
method, url, sha1, json_body, None, None, semaphore, exec,
)
.await?;
let value = serde_json::from_slice(&result)?;
Ok(value)
}
/// Downloads a file with retry and checksum functionality
#[tracing::instrument(skip(json_body, semaphore))]
#[allow(clippy::too_many_arguments)]
pub async fn fetch_advanced(
method: Method,
url: &str,
sha1: Option<&str>,
json_body: Option<serde_json::Value>,
header: Option<(&str, &str)>,
loading_bar: Option<(&LoadingBarId, f64)>,
semaphore: &FetchSemaphore,
exec: impl sqlx::Executor<'_, Database = sqlx::Sqlite>,
) -> crate::Result<Bytes> {
let _permit = semaphore.0.acquire().await?;
let is_api_url = url.starts_with(env!("MODRINTH_API_URL"))
|| url.starts_with(env!("MODRINTH_API_URL_V3"));
let creds = if header
.as_ref()
.is_none_or(|x| &*x.0.to_lowercase() != "authorization")
&& (url.starts_with("https://cdn.modrinth.com") || is_api_url)
{
crate::state::ModrinthCredentials::get_active(exec).await?
} else {
None
};
for attempt in 1..=(FETCH_ATTEMPTS + 1) {
if is_api_url && GLOBAL_FETCH_FENCE.is_blocked() {
return Err(ErrorKind::ApiIsDownError.into());
}
let mut req = REQWEST_CLIENT.request(method.clone(), url);
if let Some(body) = json_body.clone() {
req = req.json(&body);
}
if let Some(header) = header {
req = req.header(header.0, header.1);
}
if let Some(ref creds) = creds {
req = req.header("Authorization", &creds.session);
}
let result = req.send().await;
match result {
Ok(resp) => {
if resp.status().is_server_error() {
if is_api_url {
GLOBAL_FETCH_FENCE.record_fail();
}
if attempt <= FETCH_ATTEMPTS {
continue;
}
}
if resp.status().is_client_error()
|| resp.status().is_server_error()
{
let backup_error = resp.error_for_status_ref().unwrap_err();
if let Ok(error) = resp.json().await {
return Err(ErrorKind::LabrinthError(error).into());
}
return Err(backup_error.into());
}
let bytes = if let Some((bar, total)) = &loading_bar {
let length = resp.content_length();
if let Some(total_size) = length {
use futures::StreamExt;
let mut stream = resp.bytes_stream();
let mut bytes = Vec::new();
while let Some(item) = stream.next().await {
let chunk = item.or(Err(ErrorKind::NoValueFor(
"fetch bytes".to_string(),
)))?;
bytes.append(&mut chunk.to_vec());
emit_loading(
bar,
(chunk.len() as f64 / total_size as f64)
* total,
None,
)?;
}
Ok(bytes::Bytes::from(bytes))
} else {
resp.bytes().await
}
} else {
resp.bytes().await
};
if let Ok(bytes) = bytes {
if let Some(sha1) = sha1 {
let hash = sha1_async(bytes.clone()).await?;
if &*hash != sha1 {
if attempt <= FETCH_ATTEMPTS {
continue;
} else {
return Err(ErrorKind::HashError(
sha1.to_string(),
hash,
)
.into());
}
}
}
tracing::trace!("Done downloading URL {url}");
if is_api_url {
GLOBAL_FETCH_FENCE.record_ok();
}
return Ok(bytes);
} else if attempt <= FETCH_ATTEMPTS {
continue;
} else if let Err(err) = bytes {
return Err(err.into());
}
}
Err(_) if attempt <= FETCH_ATTEMPTS => continue,
Err(err) => {
return Err(err.into());
}
}
}
unreachable!()
}
/// Downloads a file from specified mirrors
#[tracing::instrument(skip(semaphore))]
pub async fn fetch_mirrors(
mirrors: &[&str],
sha1: Option<&str>,
semaphore: &FetchSemaphore,
exec: impl sqlx::Executor<'_, Database = sqlx::Sqlite> + Copy,
) -> crate::Result<Bytes> {
if mirrors.is_empty() {
return Err(
ErrorKind::InputError("No mirrors provided!".to_string()).into()
);
}
for (index, mirror) in mirrors.iter().enumerate() {
let result = fetch(mirror, sha1, semaphore, exec).await;
if result.is_ok() || (result.is_err() && index == (mirrors.len() - 1)) {
return result;
}
}
unreachable!()
}
/// Posts a JSON to a URL
#[tracing::instrument(skip(json_body, semaphore))]
pub async fn post_json<T>(
url: &str,
json_body: serde_json::Value,
semaphore: &FetchSemaphore,
exec: impl sqlx::Executor<'_, Database = sqlx::Sqlite>,
) -> crate::Result<T>
where
T: DeserializeOwned,
{
let _permit = semaphore.0.acquire().await?;
let mut req = REQWEST_CLIENT.post(url).json(&json_body);
if let Some(creds) =
crate::state::ModrinthCredentials::get_active(exec).await?
{
req = req.header("Authorization", &creds.session);
}
let result = req.send().await?.error_for_status()?;
let value = result.json().await?;
Ok(value)
}
pub async fn read_json<T>(
path: &Path,
semaphore: &IoSemaphore,
) -> crate::Result<T>
where
T: DeserializeOwned,
{
let _permit = semaphore.0.acquire().await?;
let json = io::read(path).await?;
let json = serde_json::from_slice::<T>(&json)?;
Ok(json)
}
#[tracing::instrument(skip(bytes, semaphore))]
pub async fn write(
path: &Path,
bytes: &[u8],
semaphore: &IoSemaphore,
) -> crate::Result<()> {
let _permit = semaphore.0.acquire().await?;
if let Some(parent) = path.parent() {
io::create_dir_all(parent).await?;
}
let mut file = File::create(path)
.await
.map_err(|e| IOError::with_path(e, path))?;
file.write_all(bytes)
.await
.map_err(|e| IOError::with_path(e, path))?;
tracing::trace!("Done writing file {}", path.display());
Ok(())
}
pub async fn copy(
src: impl AsRef<Path>,
dest: impl AsRef<Path>,
semaphore: &IoSemaphore,
) -> crate::Result<()> {
let src: &Path = src.as_ref();
let dest = dest.as_ref();
let _permit = semaphore.0.acquire().await?;
if let Some(parent) = dest.parent() {
io::create_dir_all(parent).await?;
}
io::copy(src, dest).await?;
tracing::trace!(
"Done copying file {} to {}",
src.display(),
dest.display()
);
Ok(())
}
// Writes a icon to the cache and returns the absolute path of the icon within the cache directory
#[tracing::instrument(skip(bytes, semaphore))]
pub async fn write_cached_icon(
icon_path: &str,
cache_dir: &Path,
bytes: Bytes,
semaphore: &IoSemaphore,
) -> crate::Result<PathBuf> {
let extension = Path::new(&icon_path).extension().and_then(OsStr::to_str);
let hash = sha1_async(bytes.clone()).await?;
let path = cache_dir.join("icons").join(if let Some(ext) = extension {
format!("{hash}.{ext}")
} else {
hash
});
write(&path, &bytes, semaphore).await?;
let path = io::canonicalize(path)?;
Ok(path)
}
pub async fn sha1_async(bytes: Bytes) -> crate::Result<String> {
let hash = tokio::task::spawn_blocking(move || {
sha1_smol::Sha1::from(bytes).hexdigest()
})
.await?;
Ok(hash)
}
#[cfg(test)]
mod tests {
use super::*;
use chrono::{TimeDelta, Utc};
#[test]
fn test_fence_block_after_4_fails() {
// Update tests if the FenceInner constants change
let mut fence = FenceInner::new();
fence.record_fail();
assert!(!fence.is_blocked());
fence.record_fail();
assert!(!fence.is_blocked());
fence.record_fail();
assert!(!fence.is_blocked());
fence.record_fail();
assert!(fence.is_blocked());
}
#[test]
fn test_fence_block_after_4_fails_with_oks() {
// Update tests if the FenceInner constants change
let mut fence = FenceInner::new();
fence.record_fail();
assert!(!fence.is_blocked());
fence.record_fail();
assert!(!fence.is_blocked());
fence.record_ok();
assert!(!fence.is_blocked());
fence.record_fail();
assert!(!fence.is_blocked());
fence.record_fail();
assert!(fence.is_blocked());
}
#[test]
fn test_fence_not_blocked_after_fails_expire() {
// Update tests if the FenceInner constants change
let mut fence = FenceInner::new();
fence.record_fail();
assert!(!fence.is_blocked());
fence.record_fail();
assert!(!fence.is_blocked());
fence.prune(Utc::now() + TimeDelta::seconds(60 * 3 + 55)); // Should prune all failures
fence.record_fail();
assert!(!fence.is_blocked());
fence.record_fail();
assert!(!fence.is_blocked());
fence.record_fail();
assert!(!fence.is_blocked());
fence.record_fail();
assert!(fence.is_blocked());
}
#[test]
fn test_fence_trigger_block_windows() {
// brute force flukes
for i in 0..128 {
let mut fence = FenceInner::new();
fence.trigger_block();
assert!(fence.is_blocked(), "Should be blocked (attempt {i})");
let block_until = fence.block_until.unwrap();
assert!(
block_until > Utc::now() + TimeDelta::seconds(60 + 55),
"Should be more than 2 minutes (with some leeway) (attempt {i})"
); // more than 2 minutes (with some leeway)
assert!(
block_until < Utc::now() + TimeDelta::seconds(60 * 5),
"Should be less than 5 minutes (attempt {i})"
); // less than 5 minutes
fence.block_until = None;
fence.trigger_block();
let block_until = fence.block_until.unwrap();
assert!(
block_until > Utc::now() + TimeDelta::seconds(60 * 3 + 55),
"Should be more than 4 minutes (with some leeway) (attempt {i})"
); // more than 4 minutes (with some leeway)
assert!(
block_until < Utc::now() + TimeDelta::seconds(60 * 10),
"Should be less than 10 minutes (attempt {i})"
); // less than 10 minutes
fence.block_until = None;
fence.trigger_block();
let block_until = fence.block_until.unwrap();
assert!(
block_until > Utc::now() + TimeDelta::seconds(60 * 5 + 55),
"Should be more than 6 minutes (with some leeway) (attempt {i})"
); // more than 6 minutes (with some leeway)
assert!(
block_until < Utc::now() + TimeDelta::seconds(60 * 15),
"Should be less than 15 minutes (attempt {i})"
); // less than 15 minutes
}
}
}