use crate::ratelimit::errors::ARError; use crate::ratelimit::{ActorMessage, ActorResponse}; use actix::dev::*; use actix_web::{ dev::{Service, ServiceRequest, ServiceResponse, Transform}, error::Error as AWError, http::header::{HeaderName, HeaderValue}, }; use futures::future::{ok, Ready}; use log::*; use std::{ cell::RefCell, future::Future, ops::Fn, pin::Pin, rc::Rc, task::{Context, Poll}, time::Duration, }; type RateLimiterIdentifier = Rc Result + 'static>>; pub struct RateLimiter where T: Handler + Send + Sync + 'static, T::Context: ToEnvelope, { interval: Duration, max_requests: usize, store: Addr, identifier: RateLimiterIdentifier, ignore_key: Option, } impl RateLimiter where T: Handler + Send + Sync + 'static, ::Context: ToEnvelope, { /// Creates a new instance of `RateLimiter` with the provided address of `StoreActor`. pub fn new(store: Addr) -> Self { let identifier = |req: &ServiceRequest| { let connection_info = req.connection_info(); let ip = connection_info.peer_addr().ok_or(ARError::Identification)?; Ok(String::from(ip)) }; RateLimiter { interval: Duration::from_secs(0), max_requests: 0, store, identifier: Rc::new(Box::new(identifier)), ignore_key: None, } } /// Specify the interval. The counter for a client is reset after this interval pub fn with_interval(mut self, interval: Duration) -> Self { self.interval = interval; self } /// Specify the maximum number of requests allowed in the given interval. pub fn with_max_requests(mut self, max_requests: usize) -> Self { self.max_requests = max_requests; self } /// Sets key which can be used to bypass rate-limiter pub fn with_ignore_key(mut self, ignore_key: Option) -> Self { self.ignore_key = ignore_key; self } /// Function to get the identifier for the client request pub fn with_identifier< F: Fn(&ServiceRequest) -> Result + 'static, >( mut self, identifier: F, ) -> Self { self.identifier = Rc::new(Box::new(identifier)); self } } impl Transform for RateLimiter where T: Handler + Send + Sync + 'static, T::Context: ToEnvelope, S: Service, Error = AWError> + 'static, S::Future: 'static, B: 'static, { type Response = ServiceResponse; type Error = S::Error; type Transform = RateLimitMiddleware; type InitError = (); type Future = Ready>; fn new_transform(&self, service: S) -> Self::Future { ok(RateLimitMiddleware { service: Rc::new(RefCell::new(service)), store: self.store.clone(), max_requests: self.max_requests, interval: self.interval.as_secs(), identifier: self.identifier.clone(), ignore_key: self.ignore_key.clone(), }) } } /// Service factory for RateLimiter pub struct RateLimitMiddleware where S: 'static, T: Handler + 'static, { service: Rc>, store: Addr, // Exists here for the sole purpose of knowing the max_requests and interval from RateLimiter max_requests: usize, interval: u64, identifier: RateLimiterIdentifier, ignore_key: Option, } impl Service for RateLimitMiddleware where T: Handler + 'static, S: Service, Error = AWError> + 'static, S::Future: 'static, B: 'static, T::Context: ToEnvelope, { type Response = ServiceResponse; type Error = S::Error; type Future = Pin>>>; fn poll_ready( &self, cx: &mut Context<'_>, ) -> Poll> { self.service.borrow_mut().poll_ready(cx) } fn call(&self, req: ServiceRequest) -> Self::Future { let store = self.store.clone(); let srv = self.service.clone(); let max_requests = self.max_requests; let interval = Duration::from_secs(self.interval); let identifier = self.identifier.clone(); let ignore_key = self.ignore_key.clone(); Box::pin(async move { let identifier: String = (identifier)(&req)?; if let Some(ignore_key) = ignore_key { if let Some(key) = req.headers().get("x-ratelimit-key") { if key.to_str().ok().unwrap_or_default() == &*ignore_key { let fut = srv.call(req); let res = fut.await?; return Ok(res); } } } let remaining: ActorResponse = store .send(ActorMessage::Get(String::from(&identifier))) .await .map_err(|_| ARError::Identification)?; match remaining { ActorResponse::Get(opt) => { let opt = opt.await?; if let Some(c) = opt { // Existing entry in store let expiry = store .send(ActorMessage::Expire(String::from( &identifier, ))) .await .map_err(|_| { ARError::ReadWrite( "Setting timeout".to_string(), ) })?; let reset: Duration = match expiry { ActorResponse::Expire(dur) => dur.await?, _ => unreachable!(), }; if c == 0 { info!("Limit exceeded for client: {}", &identifier); Err(ARError::Limited { max_requests, remaining: c, reset: reset.as_secs(), } .into()) } else { // Decrement value let res: ActorResponse = store .send(ActorMessage::Update { key: identifier, value: 1, }) .await .map_err(|_| { ARError::ReadWrite( "Decrementing ratelimit".to_string(), ) })?; let updated_value: usize = match res { ActorResponse::Update(c) => c.await?, _ => unreachable!(), }; // Execute the request let fut = srv.call(req); let mut res = fut.await?; let headers = res.headers_mut(); // Safe unwraps, since usize is always convertible to string headers.insert( HeaderName::from_static("x-ratelimit-limit"), HeaderValue::from_str( max_requests.to_string().as_str(), )?, ); headers.insert( HeaderName::from_static( "x-ratelimit-remaining", ), HeaderValue::from_str( updated_value.to_string().as_str(), )?, ); headers.insert( HeaderName::from_static("x-ratelimit-reset"), HeaderValue::from_str( reset.as_secs().to_string().as_str(), )?, ); Ok(res) } } else { // New client, create entry in store let current_value = max_requests - 1; let res = store .send(ActorMessage::Set { key: String::from(&identifier), value: current_value, expiry: interval, }) .await .map_err(|_| { ARError::ReadWrite( "Creating store entry".to_string(), ) })?; match res { ActorResponse::Set(c) => c.await?, _ => unreachable!(), } let fut = srv.call(req); let mut res = fut.await?; let headers = res.headers_mut(); // Safe unwraps, since usize is always convertible to string headers.insert( HeaderName::from_static("x-ratelimit-limit"), HeaderValue::from_str( max_requests.to_string().as_str(), ) .unwrap(), ); headers.insert( HeaderName::from_static("x-ratelimit-remaining"), HeaderValue::from_str( current_value.to_string().as_str(), ) .unwrap(), ); headers.insert( HeaderName::from_static("x-ratelimit-reset"), HeaderValue::from_str( interval.as_secs().to_string().as_str(), ) .unwrap(), ); Ok(res) } } _ => { unreachable!(); } } }) } }