1
0

Shulkers of fixes (#327)

* Shulkers of fixes

* Fix validation message

* Update deps

* Bump docker image version
This commit is contained in:
Geometrically
2022-03-27 19:12:42 -07:00
committed by GitHub
parent 7415b07586
commit d1c0c9739d
42 changed files with 683 additions and 700 deletions

View File

@@ -1,4 +1,3 @@
//! RateLimiter middleware for actix application
use crate::ratelimit::errors::ARError;
use crate::ratelimit::{ActorMessage, ActorResponse};
use actix::dev::*;
@@ -19,28 +18,9 @@ use std::{
time::Duration,
};
/// Type that implements the ratelimit middleware.
///
/// This accepts _interval_ which specifies the
/// window size, _max_requests_ which specifies the maximum number of requests in that window, and
/// _store_ which is essentially a data store used to store client access information. Entry is removed from
/// the store after _interval_.
///
/// # Example
/// ```rust
/// # use std::time::Duration;
/// use actix_ratelimit::{MemoryStore, MemoryStoreActor};
/// use actix_ratelimit::RateLimiter;
///
/// #[actix_rt::main]
/// async fn main() {
/// let store = MemoryStore::new();
/// let ratelimiter = RateLimiter::new(
/// MemoryStoreActor::from(store.clone()).start())
/// .with_interval(Duration::from_secs(60))
/// .with_max_requests(100);
/// }
/// ```
type RateLimiterIdentifier =
Rc<Box<dyn Fn(&ServiceRequest) -> Result<String, ARError> + 'static>>;
pub struct RateLimiter<T>
where
T: Handler<ActorMessage> + Send + Sync + 'static,
@@ -49,7 +29,7 @@ where
interval: Duration,
max_requests: usize,
store: Addr<T>,
identifier: Rc<Box<dyn Fn(&ServiceRequest) -> Result<String, ARError>>>,
identifier: RateLimiterIdentifier,
ignore_ips: Vec<String>,
}
@@ -62,9 +42,8 @@ where
pub fn new(store: Addr<T>) -> Self {
let identifier = |req: &ServiceRequest| {
let connection_info = req.connection_info();
let ip = connection_info
.peer_addr()
.ok_or(ARError::IdentificationError)?;
let ip =
connection_info.peer_addr().ok_or(ARError::Identification)?;
Ok(String::from(ip))
};
RateLimiter {
@@ -144,8 +123,7 @@ where
// Exists here for the sole purpose of knowing the max_requests and interval from RateLimiter
max_requests: usize,
interval: u64,
identifier:
Rc<Box<dyn Fn(&ServiceRequest) -> Result<String, ARError> + 'static>>,
identifier: RateLimiterIdentifier,
ignore_ips: Vec<String>,
}
@@ -187,7 +165,7 @@ where
let remaining: ActorResponse = store
.send(ActorMessage::Get(String::from(&identifier)))
.await
.map_err(|_| ARError::IdentificationError)?;
.map_err(|_| ARError::Identification)?;
match remaining {
ActorResponse::Get(opt) => {
let opt = opt.await?;
@@ -199,7 +177,7 @@ where
)))
.await
.map_err(|_| {
ARError::ReadWriteError(
ARError::ReadWrite(
"Setting timeout".to_string(),
)
})?;
@@ -209,7 +187,7 @@ where
};
if c == 0 {
info!("Limit exceeded for client: {}", &identifier);
Err(ARError::LimitedError {
Err(ARError::Limited {
max_requests,
remaining: c,
reset: reset.as_secs(),
@@ -224,7 +202,7 @@ where
})
.await
.map_err(|_| {
ARError::ReadWriteError(
ARError::ReadWrite(
"Decrementing ratelimit".to_string(),
)
})?;
@@ -270,7 +248,7 @@ where
})
.await
.map_err(|_| {
ARError::ReadWriteError(
ARError::ReadWrite(
"Creating store entry".to_string(),
)
})?;