use std::{ cmp, ops::{Add, Sub}, }; use derive_more::{Deref, Display, Error}; use rust_decimal::{Decimal, RoundingStrategy}; use serde::{Deserialize, Serialize}; #[derive( Debug, Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Deref, Serialize, Deserialize, )] #[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))] #[serde(try_from = "Decimal")] pub struct DecimalDp(Decimal); pub type Decimal2dp = DecimalDp<2>; #[derive(Debug, Display, Clone, Error)] #[display("decimal is not rounded to {dp} decimal places")] pub struct NotRounded { pub dp: u32, } impl DecimalDp { pub const ZERO: Self = Self(Decimal::ZERO); pub fn rounded(v: Decimal, strategy: RoundingStrategy) -> Self { Self(v.round_dp_with_strategy(DP, strategy)) } pub fn new(v: Decimal) -> Result { if v.round_dp(DP) == v { Ok(Self(v)) } else { Err(NotRounded { dp: DP }) } } pub fn get(self) -> Decimal { self.0 } pub fn mul_round( self, other: impl Into, strategy: RoundingStrategy, ) -> Self { Self::rounded(self.0 * other.into(), strategy) } } // conversion impl TryFrom for DecimalDp { type Error = NotRounded; fn try_from(value: Decimal) -> Result { Self::new(value) } } impl From> for Decimal { fn from(value: DecimalDp) -> Self { value.0 } } // ord impl PartialOrd for DecimalDp { fn partial_cmp(&self, other: &Decimal) -> Option { self.0.partial_cmp(other) } } impl PartialOrd> for Decimal { fn partial_cmp(&self, other: &DecimalDp) -> Option { self.partial_cmp(&other.0) } } // eq impl PartialEq for DecimalDp { fn eq(&self, other: &Decimal) -> bool { self.0.eq(other) } } impl PartialEq> for Decimal { fn eq(&self, other: &DecimalDp) -> bool { self.eq(&other.0) } } // add impl Add for DecimalDp { type Output = Self; fn add(self, rhs: DecimalDp) -> Self::Output { let v = self.0 + rhs.0; debug_assert!(Self::new(v).is_ok()); Self(v) } } impl Add for DecimalDp { type Output = Decimal; fn add(self, rhs: Decimal) -> Self::Output { self.0 + rhs } } impl Add> for Decimal { type Output = Decimal; fn add(self, rhs: DecimalDp) -> Self::Output { self + rhs.0 } } // sub impl Sub for DecimalDp { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { let v = self.0 - rhs.0; debug_assert!(Self::new(v).is_ok()); Self(v) } } impl Sub for DecimalDp { type Output = Decimal; fn sub(self, rhs: Decimal) -> Self::Output { self.0 - rhs } } impl Sub> for Decimal { type Output = Decimal; fn sub(self, rhs: DecimalDp) -> Self::Output { self - rhs.0 } } #[cfg(test)] mod test { use super::*; use rust_decimal::dec; #[test] fn new() { Decimal2dp::new(dec!(1)).unwrap(); Decimal2dp::new(dec!(1.0)).unwrap(); Decimal2dp::new(dec!(1.1)).unwrap(); Decimal2dp::new(dec!(1.01)).unwrap(); Decimal2dp::new(dec!(1.00)).unwrap(); Decimal2dp::new(dec!(1.000)).unwrap(); Decimal2dp::new(dec!(1.001)).unwrap_err(); } #[test] fn rounded() { assert_eq!( dec!(1), Decimal2dp::rounded(dec!(1), RoundingStrategy::ToZero) ); assert_eq!( dec!(1), Decimal2dp::rounded(dec!(1.001), RoundingStrategy::ToZero) ); assert_eq!( dec!(1), Decimal2dp::rounded(dec!(1.005), RoundingStrategy::ToZero) ); assert_eq!( dec!(1), Decimal2dp::rounded(dec!(1.009), RoundingStrategy::ToZero) ); assert_eq!( dec!(1.01), Decimal2dp::rounded(dec!(1.010), RoundingStrategy::ToZero) ); } #[test] fn deserialize() { serde_json::from_str::("1").unwrap(); serde_json::from_str::("1.0").unwrap(); serde_json::from_str::("1.00").unwrap(); serde_json::from_str::("1.000").unwrap(); serde_json::from_str::("1.001").unwrap_err(); } #[test] fn ops() { assert_eq!( Decimal2dp::new(dec!(1.23)).unwrap() + Decimal2dp::new(dec!(0.27)).unwrap(), dec!(1.50) ); assert_eq!( Decimal2dp::new(dec!(1.23)).unwrap() - Decimal2dp::new(dec!(0.23)).unwrap(), dec!(1.00) ); } }