use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;

use axum::{
    extract::{Request, State},
    http::{HeaderMap, HeaderValue, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};
use governor::{
    Quota, RateLimiter,
    clock::{Clock, DefaultClock},
    state::{InMemoryState, NotKeyed, keyed::DashMapStateStore},
};
use serde_json::json;

use crate::client_ip;

/// Per-IP rate limiter (keyed by client IP string)
type IpRateLimiter = RateLimiter<String, DashMapStateStore<String>, DefaultClock>;

/// Global rate limiter (not keyed)
type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;

/// Convert a configured `u32` into a [`NonZeroU32`], logging a warning and
/// falling back to `default` if the value is zero.
///
/// `name` is the configuration key the value came from; it is only used in
/// the warning message.
///
/// # Panics
///
/// Panics if `default` itself is zero. Every call site passes a non-zero
/// literal, so hitting this is a programming error, not a configuration error.
fn nz_or_log(value: u32, default: u32, name: &str) -> NonZeroU32 {
    NonZeroU32::new(value).unwrap_or_else(|| {
        tracing::warn!(
            configured = value,
            default,
            "{name} must be non-zero, using default"
        );
        NonZeroU32::new(default).expect("default rate limit must be non-zero")
    })
}

/// Read a `u32` setting from the environment variable `name`.
///
/// Returns `default` when the variable is unset. A variable that is set but
/// does not hold a valid `u32` (surrounding whitespace is ignored) also yields
/// `default`, but with a warning so a typo in the deployment configuration
/// does not go unnoticed.
fn env_u32(name: &str, default: u32) -> u32 {
    match std::env::var(name) {
        Ok(raw) => match raw.trim().parse::<u32>() {
            Ok(parsed) => parsed,
            Err(err) => {
                tracing::warn!(
                    configured = %raw,
                    error = %err,
                    default,
                    "{name} is not a valid u32, using default"
                );
                default
            }
        },
        Err(std::env::VarError::NotPresent) => default,
        Err(std::env::VarError::NotUnicode(_)) => {
            tracing::warn!(default, "{name} is not valid unicode, using default");
            default
        }
    }
}

/// Resolve one non-zero limit: environment override first, `default` otherwise.
fn limit_from_env(name: &str, default: u32) -> NonZeroU32 {
    nz_or_log(env_u32(name, default), default, name)
}

/// Shared rate limiting state handed to the middleware.
///
/// Both limiters sit behind an [`Arc`], so cloning this struct only bumps
/// reference counts and every clone observes the same limiter state.
#[derive(Clone)]
pub struct RateLimitState {
    /// Limiter keyed by client IP string.
    pub ip_limiter: Arc<IpRateLimiter>,
    /// Single limiter shared by all requests.
    pub global_limiter: Arc<GlobalRateLimiter>,
}

impl RateLimitState {
    /// Create a new RateLimitState with default limits.
    ///
    /// Default limits (can be overridden via environment variables):
    /// - Global: 100 req/s, burst 200
    ///   (`RATE_LIMIT_GLOBAL_PER_SECOND`, `RATE_LIMIT_GLOBAL_BURST`)
    /// - Per-IP: 20 req/s, burst 40
    ///   (`RATE_LIMIT_IP_PER_SECOND`, `RATE_LIMIT_IP_BURST`)
    ///
    /// Unset, unparsable, or zero values fall back to the defaults above; the
    /// latter two cases are logged as warnings.
    pub fn new() -> Self {
        let global_rate_nz = limit_from_env("RATE_LIMIT_GLOBAL_PER_SECOND", 100);
        let global_burst_nz = limit_from_env("RATE_LIMIT_GLOBAL_BURST", 200);
        let ip_rate_nz = limit_from_env("RATE_LIMIT_IP_PER_SECOND", 20);
        let ip_burst_nz = limit_from_env("RATE_LIMIT_IP_BURST", 40);

        let global_quota = Quota::per_second(global_rate_nz).allow_burst(global_burst_nz);
        let ip_quota = Quota::per_second(ip_rate_nz).allow_burst(ip_burst_nz);

        tracing::info!(
            global_rate = global_rate_nz.get(),
            global_burst = global_burst_nz.get(),
            ip_rate = ip_rate_nz.get(),
            ip_burst = ip_burst_nz.get(),
            "rate limiter initialized"
        );

        Self {
            global_limiter: Arc::new(RateLimiter::direct(global_quota)),
            ip_limiter: Arc::new(RateLimiter::dashmap(ip_quota)),
        }
    }
}

impl Default for RateLimitState {
    /// Equivalent to [`RateLimitState::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Rate limiting middleware function.
///
/// Checks both global and per-IP rate limits before allowing the request through.
/// Returns 429 Too Many Requests if either limit is exceeded.
pub async fn rate_limit_middleware( State(rl): State, req: Request, next: Next, ) -> Result { // Check global rate limit first if let Err(negative) = rl.global_limiter.check() { let retry_after = negative.wait_time_from(DefaultClock::default().now()); tracing::warn!( retry_after_secs = retry_after.as_secs(), "global rate limit exceeded" ); return Err(too_many_requests_response(Some(retry_after))); } // Check per-IP rate limit let key = client_ip::extract_client_ip(&req); if let Err(negative) = rl.ip_limiter.check_key(&key) { let retry_after = negative.wait_time_from(DefaultClock::default().now()); tracing::warn!( client_ip = %key, retry_after_secs = retry_after.as_secs(), "per-IP rate limit exceeded" ); return Err(too_many_requests_response(Some(retry_after))); } Ok(next.run(req).await) } /// Start a background task to clean up expired rate limiter entries. /// /// This should be called once during application startup. /// The task runs every 60 seconds and will be aborted on shutdown. pub fn spawn_cleanup_task(ip_limiter: Arc) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { let mut interval = tokio::time::interval(Duration::from_secs(60)); loop { interval.tick().await; ip_limiter.retain_recent(); } }) } /// Create a 429 Too Many Requests response. fn too_many_requests_response(retry_after: Option) -> Response { let mut headers = HeaderMap::new(); headers.insert("Content-Type", HeaderValue::from_static("application/json")); if let Some(duration) = retry_after { let secs = duration.as_secs().max(1); if let Ok(value) = HeaderValue::from_str(&secs.to_string()) { headers.insert("Retry-After", value); } } let body = json!({ "error": "Too many requests, please try again later" }); (StatusCode::TOO_MANY_REQUESTS, headers, body.to_string()).into_response() }