diff --git a/Cargo.lock b/Cargo.lock index 16f8f2661..6efc9b9d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1432,6 +1432,7 @@ dependencies = [ "sqlx", "thiserror", "time 0.2.27", + "tokio", "tokio-stream", "url", "urlencoding", @@ -2761,9 +2762,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c" [[package]] name = "tokio" -version = "1.18.2" +version = "1.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4903bf0427cf68dddd5aa6a93220756f8be0c34fcfa9f5e6191e103e15a31395" +checksum = "95eec79ea28c00a365f539f1961e9278fbcaf81c0ff6aaf0e93c181352446948" dependencies = [ "bytes", "libc", diff --git a/Cargo.toml b/Cargo.toml index 518c08705..2891d4323 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ path = "src/main.rs" actix = "0.13.0" actix-web = { git = "https://github.com/modrinth/actix-web", rev = "88c7c18" } actix-rt = "2.7.0" +tokio = { version = "1.19.0", features = ["sync"] } tokio-stream = "0.1.8" actix-multipart = { git = "https://github.com/modrinth/actix-web", rev = "88c7c18" } actix-cors = { git = "https://github.com/modrinth/actix-extras.git", rev = "34d301f" } diff --git a/src/database/models/ids.rs b/src/database/models/ids.rs index a5b2a6f3c..b0ad29dff 100644 --- a/src/database/models/ids.rs +++ b/src/database/models/ids.rs @@ -133,7 +133,7 @@ pub struct LicenseId(pub i32); #[sqlx(transparent)] pub struct DonationPlatformId(pub i32); -#[derive(Copy, Clone, Debug, Type, PartialEq)] +#[derive(Copy, Clone, Debug, Type, PartialEq, Eq, Hash)] #[sqlx(transparent)] pub struct VersionId(pub i64); #[derive(Copy, Clone, Debug, Type)] diff --git a/src/routes/version_file.rs b/src/routes/version_file.rs index af27ccc9c..3eddab6a2 100644 --- a/src/routes/version_file.rs +++ b/src/routes/version_file.rs @@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize}; use sqlx::PgPool; use std::collections::HashMap; use std::sync::Arc; +use tokio::sync::RwLock; #[derive(Deserialize)] pub struct Algorithm { @@ -428,9 +429,10 @@ pub async fn update_files( .fetch_all(&mut *transaction) .await?; - let mut version_ids = Vec::new(); + let version_ids: RwLock>> = + RwLock::new(HashMap::new()); - for row in &result { + futures::future::try_join_all(result.into_iter().map(|row| async { let updated_versions = database::models::Version::get_project_versions( database::models::ProjectId(row.project_id), Some( @@ -454,28 +456,40 @@ pub async fn update_files( .await?; if let Some(latest_version) = updated_versions.last() { - version_ids.push(*latest_version); - } - } + let mut version_ids = version_ids.write().await; - let versions = - database::models::Version::get_many_full(version_ids, &**pool).await?; + version_ids.insert(*latest_version, row.hash); + } + + Ok::<(), ApiError>(()) + })) + .await?; + + let version_ids = version_ids.into_inner(); + + let versions = database::models::Version::get_many_full( + version_ids.keys().copied().collect(), + &**pool, + ) + .await?; let mut response = HashMap::new(); - for row in &result { - if let Some(version) = - versions.iter().find(|x| x.id.0 == row.version_id) - { - if let Ok(parsed_hash) = String::from_utf8(row.hash.clone()) { + for version in versions { + let hash = version_ids.get(&version.id); + + if let Some(hash) = hash { + if let Ok(parsed_hash) = String::from_utf8(hash.clone()) { response.insert( parsed_hash, - models::projects::Version::from(version.clone()), + models::projects::Version::from(version), ); } else { + let version_id: models::projects::VersionId = version.id.into(); + return Err(ApiError::Database(DatabaseError::Other(format!( "Could not parse hash for version {}", - row.version_id + version_id )))); } }