diff --git a/Cargo.lock b/Cargo.lock index 4fb3b0b0..3cea3b20 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5082,10 +5082,13 @@ dependencies = [ "derive_more 2.0.1", "dotenvy", "eyre", + "rust_decimal", "serde", + "serde_json", "tracing", "tracing-ecs", "tracing-subscriber", + "utoipa", ] [[package]] diff --git a/apps/labrinth/Cargo.toml b/apps/labrinth/Cargo.toml index 227a8b23..6920d146 100644 --- a/apps/labrinth/Cargo.toml +++ b/apps/labrinth/Cargo.toml @@ -71,8 +71,8 @@ json-patch = { workspace = true } lettre = { workspace = true } meilisearch-sdk = { workspace = true, features = ["reqwest"] } modrinth-maxmind = { workspace = true } -modrinth-util = { workspace = true } -muralpay = { workspace = true, features = ["utoipa", "mock"] } +modrinth-util = { workspace = true, features = ["decimal", "utoipa"] } +muralpay = { workspace = true, features = ["mock", "utoipa"] } murmur2 = { workspace = true } paste = { workspace = true } path-util = { workspace = true } diff --git a/apps/labrinth/src/models/v3/payouts.rs b/apps/labrinth/src/models/v3/payouts.rs index d56edb18..7e2c2428 100644 --- a/apps/labrinth/src/models/v3/payouts.rs +++ b/apps/labrinth/src/models/v3/payouts.rs @@ -252,9 +252,9 @@ pub struct PayoutMethodFee { } impl PayoutMethodFee { - pub fn compute_fee(&self, value: Decimal) -> Decimal { + pub fn compute_fee(&self, value: impl Into) -> Decimal { cmp::min( - cmp::max(self.min, self.percentage * value), + cmp::max(self.min, self.percentage * value.into()), self.max.unwrap_or(Decimal::MAX), ) } diff --git a/apps/labrinth/src/queue/payouts/mod.rs b/apps/labrinth/src/queue/payouts/mod.rs index 5fd3fb75..6beba756 100644 --- a/apps/labrinth/src/queue/payouts/mod.rs +++ b/apps/labrinth/src/queue/payouts/mod.rs @@ -20,10 +20,11 @@ use chrono::{DateTime, Datelike, Duration, NaiveTime, TimeZone, Utc}; use dashmap::DashMap; use eyre::{Result, eyre}; use futures::TryStreamExt; +use modrinth_util::decimal::Decimal2dp; use muralpay::MuralPay; use reqwest::Method; use rust_decimal::prelude::ToPrimitive; -use rust_decimal::{Decimal, dec}; +use rust_decimal::{Decimal, RoundingStrategy, dec}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -618,7 +619,7 @@ impl PayoutsQueue { &self, request: &PayoutMethodRequest, method_id: &str, - amount: Decimal, + amount: Decimal2dp, ) -> Result { const MURAL_FEE: Decimal = dec!(0.01); @@ -641,8 +642,9 @@ impl PayoutsQueue { .. }, } => PayoutFees { - method_fee: dec!(0), - platform_fee: amount * MURAL_FEE, + method_fee: Decimal2dp::ZERO, + platform_fee: amount + .mul_round(MURAL_FEE, RoundingStrategy::AwayFromZero), exchange_rate: None, }, PayoutMethodRequest::MuralPay { @@ -667,8 +669,14 @@ impl PayoutsQueue { fee_total, .. } => PayoutFees { - method_fee: fee_total.token_amount, - platform_fee: amount * MURAL_FEE, + method_fee: Decimal2dp::rounded( + fee_total.token_amount, + RoundingStrategy::AwayFromZero, + ), + platform_fee: amount.mul_round( + MURAL_FEE, + RoundingStrategy::AwayFromZero, + ), exchange_rate: Some(exchange_rate), }, muralpay::TokenPayoutFee::Error { message, .. } => { @@ -680,16 +688,22 @@ impl PayoutsQueue { } PayoutMethodRequest::PayPal | PayoutMethodRequest::Venmo => { let method = get_method.await?; - let fee = method.fee.compute_fee(amount); + let fee = Decimal2dp::rounded( + method.fee.compute_fee(amount), + RoundingStrategy::AwayFromZero, + ); PayoutFees { method_fee: fee, - platform_fee: dec!(0), + platform_fee: Decimal2dp::ZERO, exchange_rate: None, } } PayoutMethodRequest::Tremendous { method_details } => { let method = get_method.await?; - let fee = method.fee.compute_fee(amount); + let fee = Decimal2dp::rounded( + method.fee.compute_fee(amount), + RoundingStrategy::AwayFromZero, + ); let forex: TremendousForexResponse = self .make_tremendous_request(Method::GET, "forex", None::<()>) @@ -718,7 +732,7 @@ impl PayoutsQueue { // we send the request to Tremendous. Afterwards, the method // (Tremendous) will take 0% off the top of our $10. PayoutFees { - method_fee: dec!(0), + method_fee: Decimal2dp::ZERO, platform_fee: fee, exchange_rate, } @@ -736,19 +750,19 @@ pub struct PayoutFees { /// For example, if a user withdraws $10.00 and the method takes a /// 10% cut, then we submit a payout request of $10.00 to the method, /// and only $9.00 will be sent to the recipient. - pub method_fee: Decimal, + pub method_fee: Decimal2dp, /// Fee which we keep and don't pass to the underlying method. /// /// For example, if a user withdraws $10.00 and the method takes a /// 10% cut, then we submit a payout request of $9.00, and the $1.00 stays /// in our account. - pub platform_fee: Decimal, + pub platform_fee: Decimal2dp, /// How much is 1 USD worth in the target currency? pub exchange_rate: Option, } impl PayoutFees { - pub fn total_fee(&self) -> Decimal { + pub fn total_fee(&self) -> Decimal2dp { self.method_fee + self.platform_fee } } diff --git a/apps/labrinth/src/queue/payouts/mural.rs b/apps/labrinth/src/queue/payouts/mural.rs index 56552aa7..01fe6ca4 100644 --- a/apps/labrinth/src/queue/payouts/mural.rs +++ b/apps/labrinth/src/queue/payouts/mural.rs @@ -2,6 +2,7 @@ use ariadne::ids::UserId; use chrono::Utc; use eyre::{Result, eyre}; use futures::{StreamExt, TryFutureExt, stream::FuturesUnordered}; +use modrinth_util::decimal::Decimal2dp; use muralpay::{MuralError, MuralPay, TokenFeeRequest}; use rust_decimal::{Decimal, prelude::ToPrimitive}; use serde::{Deserialize, Serialize}; @@ -35,7 +36,7 @@ pub enum MuralPayoutRequest { impl PayoutsQueue { pub async fn compute_muralpay_fees( &self, - amount: Decimal, + amount: Decimal2dp, fiat_and_rail_code: muralpay::FiatAndRailCode, ) -> Result { let muralpay = self.muralpay.load(); @@ -48,7 +49,7 @@ impl PayoutsQueue { .get_fees_for_token_amount(&[TokenFeeRequest { amount: muralpay::TokenAmount { token_symbol: muralpay::USDC.into(), - token_amount: amount, + token_amount: amount.get(), }, fiat_and_rail_code, }]) @@ -65,7 +66,7 @@ impl PayoutsQueue { &self, payout_id: DBPayoutId, user_id: UserId, - gross_amount: Decimal, + gross_amount: Decimal2dp, fees: PayoutFees, payout_details: MuralPayoutRequest, recipient_info: muralpay::PayoutRecipientInfo, @@ -107,9 +108,9 @@ impl PayoutsQueue { let recipient_address = recipient_info.physical_address(); let recipient_email = recipient_info.email().to_string(); - let gross_amount_cents = gross_amount * Decimal::from(100); - let net_amount_cents = net_amount * Decimal::from(100); - let fees_cents = fees.total_fee() * Decimal::from(100); + let gross_amount_cents = gross_amount.get() * Decimal::from(100); + let net_amount_cents = net_amount.get() * Decimal::from(100); + let fees_cents = fees.total_fee().get() * Decimal::from(100); let address_line_3 = format!( "{}, {}, {}", recipient_address.city, @@ -153,7 +154,7 @@ impl PayoutsQueue { let payout = muralpay::CreatePayout { amount: muralpay::TokenAmount { - token_amount: sent_to_method, + token_amount: sent_to_method.get(), token_symbol: muralpay::USDC.into(), }, payout_details, diff --git a/apps/labrinth/src/routes/v3/payouts.rs b/apps/labrinth/src/routes/v3/payouts.rs index 7cf4319e..cec3f129 100644 --- a/apps/labrinth/src/routes/v3/payouts.rs +++ b/apps/labrinth/src/routes/v3/payouts.rs @@ -22,8 +22,9 @@ use chrono::{DateTime, Duration, Utc}; use eyre::eyre; use hex::ToHex; use hmac::{Hmac, Mac}; +use modrinth_util::decimal::Decimal2dp; use reqwest::Method; -use rust_decimal::Decimal; +use rust_decimal::{Decimal, RoundingStrategy}; use serde::{Deserialize, Serialize}; use serde_json::json; use sha2::Sha256; @@ -423,8 +424,7 @@ pub async fn tremendous_webhook( #[derive(Debug, Serialize, Deserialize, utoipa::ToSchema)] pub struct Withdrawal { - #[serde(with = "rust_decimal::serde::float")] - amount: Decimal, + amount: Decimal2dp, #[serde(flatten)] method: PayoutMethodRequest, method_id: String, @@ -432,7 +432,7 @@ pub struct Withdrawal { #[derive(Debug, Serialize, Deserialize)] pub struct WithdrawalFees { - pub fee: Decimal, + pub fee: Decimal2dp, pub exchange_rate: Option, } @@ -583,7 +583,8 @@ pub async fn create_payout( let fees = payouts_queue .calculate_fees(&body.method, &body.method_id, body.amount) - .await?; + .await + .wrap_internal_err("failed to compute fees")?; // fees are a bit complicated here, since we have 2 types: // - method fees - this is what Tremendous, Mural, etc. will take from us @@ -595,14 +596,18 @@ pub async fn create_payout( // then we issue a payout request with `amount - platform fees` let amount_minus_fee = body.amount - fees.total_fee(); - if amount_minus_fee.round_dp(2) <= Decimal::ZERO { + if amount_minus_fee <= Decimal::ZERO { return Err(ApiError::InvalidInput( "You need to withdraw more to cover the fee!".to_string(), )); } - let sent_to_method = (body.amount - fees.platform_fee).round_dp(2); - assert!(sent_to_method > Decimal::ZERO); + let sent_to_method = body.amount - fees.platform_fee; + if sent_to_method <= Decimal::ZERO { + return Err(ApiError::InvalidInput( + "You need to withdraw more to cover the fee!".to_string(), + )); + } let payout_id = generate_payout_id(&mut transaction) .await @@ -653,13 +658,13 @@ struct PayoutContext<'a> { body: &'a Withdrawal, user: &'a DBUser, payout_id: DBPayoutId, - gross_amount: Decimal, + gross_amount: Decimal2dp, fees: PayoutFees, /// Set as the [`DBPayout::amount`] field. - amount_minus_fee: Decimal, + amount_minus_fee: Decimal2dp, /// Set as the [`DBPayout::fee`] field. - total_fee: Decimal, - sent_to_method: Decimal, + total_fee: Decimal2dp, + sent_to_method: Decimal2dp, payouts_queue: &'a PayoutsQueue, } @@ -721,7 +726,10 @@ async fn tremendous_payout( forex.forex.get(¤cy_code).wrap_internal_err_with(|| { eyre!("no Tremendous forex data for {currency}") })?; - (sent_to_method * *exchange_rate, Some(currency_code)) + ( + sent_to_method.mul_round(*exchange_rate, RoundingStrategy::ToZero), + Some(currency_code), + ) } else { (sent_to_method, None) }; @@ -770,8 +778,8 @@ async fn tremendous_payout( user_id: user.id, created: Utc::now(), status: PayoutStatus::InTransit, - amount: amount_minus_fee, - fee: Some(total_fee), + amount: amount_minus_fee.get(), + fee: Some(total_fee.get()), method: Some(PayoutMethodType::Tremendous), method_id: Some(body.method_id.clone()), method_address: Some(user_email.to_string()), @@ -825,8 +833,8 @@ async fn mural_pay_payout( // after the payout has been successfully executed, // we wait for Mural's confirmation that the funds have been delivered status: PayoutStatus::InTransit, - amount: amount_minus_fee, - fee: Some(total_fee), + amount: amount_minus_fee.get(), + fee: Some(total_fee.get()), method: Some(PayoutMethodType::MuralPay), method_id: Some(method_id), method_address: Some(user_email.to_string()), @@ -962,8 +970,8 @@ async fn paypal_payout( user_id: user.id, created: Utc::now(), status: PayoutStatus::InTransit, - amount: amount_minus_fee, - fee: Some(total_fee), + amount: amount_minus_fee.get(), + fee: Some(total_fee.get()), method: Some(body.method.method_type()), method_id: Some(body.method_id.clone()), method_address: Some(display_address.clone()), diff --git a/packages/modrinth-util/Cargo.toml b/packages/modrinth-util/Cargo.toml index a5af7bc7..1fa164f5 100644 --- a/packages/modrinth-util/Cargo.toml +++ b/packages/modrinth-util/Cargo.toml @@ -9,10 +9,19 @@ actix-web = { workspace = true } derive_more = { workspace = true, features = ["display", "error", "from"] } dotenvy = { workspace = true } eyre = { workspace = true } +rust_decimal = { workspace = true, features = ["macros"], optional = true } serde = { workspace = true, features = ["derive"] } tracing = { workspace = true } tracing-ecs = { workspace = true } tracing-subscriber = { workspace = true } +utoipa = { workspace = true, optional = true } + +[dev-dependencies] +serde_json = { workspace = true } + +[features] +decimal = ["dep:rust_decimal", "utoipa?/decimal"] +utoipa = ["dep:utoipa"] [lints] workspace = true diff --git a/packages/modrinth-util/src/decimal.rs b/packages/modrinth-util/src/decimal.rs new file mode 100644 index 00000000..ec5771b3 --- /dev/null +++ b/packages/modrinth-util/src/decimal.rs @@ -0,0 +1,226 @@ +use std::{ + cmp, + ops::{Add, Sub}, +}; + +use derive_more::{Deref, Display, Error}; +use rust_decimal::{Decimal, RoundingStrategy}; +use serde::{Deserialize, Serialize}; + +#[derive( + Debug, + Display, + Clone, + Copy, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + Deref, + Serialize, + Deserialize, +)] +#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] +#[serde(try_from = "Decimal")] +pub struct DecimalDp(Decimal); + +pub type Decimal2dp = DecimalDp<2>; + +#[derive(Debug, Display, Clone, Error)] +#[display("decimal is not rounded to {dp} decimal places")] +pub struct NotRounded { + pub dp: u32, +} + +impl DecimalDp { + pub const ZERO: Self = Self(Decimal::ZERO); + + pub fn rounded(v: Decimal, strategy: RoundingStrategy) -> Self { + Self(v.round_dp_with_strategy(DP, strategy)) + } + + pub fn new(v: Decimal) -> Result { + if v.round_dp(DP) == v { + Ok(Self(v)) + } else { + Err(NotRounded { dp: DP }) + } + } + + pub fn get(self) -> Decimal { + self.0 + } + + pub fn mul_round( + self, + other: impl Into, + strategy: RoundingStrategy, + ) -> Self { + Self::rounded(self.0 * other.into(), strategy) + } +} + +// conversion + +impl TryFrom for DecimalDp { + type Error = NotRounded; + + fn try_from(value: Decimal) -> Result { + Self::new(value) + } +} + +impl From> for Decimal { + fn from(value: DecimalDp) -> Self { + value.0 + } +} + +// ord + +impl PartialOrd for DecimalDp { + fn partial_cmp(&self, other: &Decimal) -> Option { + self.0.partial_cmp(other) + } +} + +impl PartialOrd> for Decimal { + fn partial_cmp(&self, other: &DecimalDp) -> Option { + self.partial_cmp(&other.0) + } +} + +// eq + +impl PartialEq for DecimalDp { + fn eq(&self, other: &Decimal) -> bool { + self.0.eq(other) + } +} + +impl PartialEq> for Decimal { + fn eq(&self, other: &DecimalDp) -> bool { + self.eq(&other.0) + } +} + +// add + +impl Add for DecimalDp { + type Output = Self; + + fn add(self, rhs: DecimalDp) -> Self::Output { + let v = self.0 + rhs.0; + debug_assert!(Self::new(v).is_ok()); + Self(v) + } +} + +impl Add for DecimalDp { + type Output = Decimal; + + fn add(self, rhs: Decimal) -> Self::Output { + self.0 + rhs + } +} + +impl Add> for Decimal { + type Output = Decimal; + + fn add(self, rhs: DecimalDp) -> Self::Output { + self + rhs.0 + } +} + +// sub + +impl Sub for DecimalDp { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + let v = self.0 - rhs.0; + debug_assert!(Self::new(v).is_ok()); + Self(v) + } +} + +impl Sub for DecimalDp { + type Output = Decimal; + + fn sub(self, rhs: Decimal) -> Self::Output { + self.0 - rhs + } +} + +impl Sub> for Decimal { + type Output = Decimal; + + fn sub(self, rhs: DecimalDp) -> Self::Output { + self - rhs.0 + } +} + +#[cfg(test)] +mod test { + use super::*; + use rust_decimal::dec; + + #[test] + fn new() { + Decimal2dp::new(dec!(1)).unwrap(); + Decimal2dp::new(dec!(1.0)).unwrap(); + Decimal2dp::new(dec!(1.1)).unwrap(); + Decimal2dp::new(dec!(1.01)).unwrap(); + Decimal2dp::new(dec!(1.00)).unwrap(); + Decimal2dp::new(dec!(1.000)).unwrap(); + Decimal2dp::new(dec!(1.001)).unwrap_err(); + } + + #[test] + fn rounded() { + assert_eq!( + dec!(1), + Decimal2dp::rounded(dec!(1), RoundingStrategy::ToZero) + ); + assert_eq!( + dec!(1), + Decimal2dp::rounded(dec!(1.001), RoundingStrategy::ToZero) + ); + assert_eq!( + dec!(1), + Decimal2dp::rounded(dec!(1.005), RoundingStrategy::ToZero) + ); + assert_eq!( + dec!(1), + Decimal2dp::rounded(dec!(1.009), RoundingStrategy::ToZero) + ); + assert_eq!( + dec!(1.01), + Decimal2dp::rounded(dec!(1.010), RoundingStrategy::ToZero) + ); + } + + #[test] + fn deserialize() { + serde_json::from_str::("1").unwrap(); + serde_json::from_str::("1.0").unwrap(); + serde_json::from_str::("1.00").unwrap(); + serde_json::from_str::("1.000").unwrap(); + serde_json::from_str::("1.001").unwrap_err(); + } + + #[test] + fn ops() { + assert_eq!( + Decimal2dp::new(dec!(1.23)).unwrap() + + Decimal2dp::new(dec!(0.27)).unwrap(), + dec!(1.50) + ); + assert_eq!( + Decimal2dp::new(dec!(1.23)).unwrap() + - Decimal2dp::new(dec!(0.23)).unwrap(), + dec!(1.00) + ); + } +} diff --git a/packages/modrinth-util/src/lib.rs b/packages/modrinth-util/src/lib.rs index bb0c8c76..e284f70c 100644 --- a/packages/modrinth-util/src/lib.rs +++ b/packages/modrinth-util/src/lib.rs @@ -3,6 +3,9 @@ mod error; pub mod log; +#[cfg(feature = "decimal")] +pub mod decimal; + pub use error::*; use eyre::{Result, eyre};