Bump version: secrets-mcp-0.5.1 tag already existed while crates had further changes. Made-with: Cursor
161 lines
5.3 KiB
Rust
161 lines
5.3 KiB
Rust
use std::num::NonZeroU32;
|
|
use std::sync::Arc;
|
|
use std::time::Duration;
|
|
|
|
use axum::{
|
|
extract::{Request, State},
|
|
http::{HeaderMap, HeaderValue, StatusCode},
|
|
middleware::Next,
|
|
response::{IntoResponse, Response},
|
|
};
|
|
use governor::{
|
|
Quota, RateLimiter,
|
|
clock::{Clock, DefaultClock},
|
|
state::{InMemoryState, NotKeyed, keyed::DashMapStateStore},
|
|
};
|
|
use serde_json::json;
|
|
|
|
use crate::client_ip;
|
|
|
|
/// Per-IP rate limiter (keyed by client IP string)
///
/// Keys are the strings produced by `client_ip::extract_client_ip`; each key gets its
/// own quota. State lives in a `DashMap`-backed store, and stale entries are pruned by
/// the periodic `retain_recent` call in `spawn_cleanup_task`.
type IpRateLimiter = RateLimiter<String, DashMapStateStore<String>, DefaultClock>;

/// Global rate limiter (not keyed)
///
/// A single quota that is not partitioned by client: every request checked against it
/// draws from the same budget.
type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
|
|
|
|
/// Parse a u32 env value into NonZeroU32, logging a warning and falling back
|
|
/// to the default if the value is zero.
|
|
fn nz_or_log(value: u32, default: u32, name: &str) -> NonZeroU32 {
|
|
NonZeroU32::new(value).unwrap_or_else(|| {
|
|
tracing::warn!(
|
|
configured = value,
|
|
default,
|
|
"{name} must be non-zero, using default"
|
|
);
|
|
NonZeroU32::new(default).unwrap()
|
|
})
|
|
}
|
|
|
|
/// Shared state for `rate_limit_middleware`: one un-keyed limiter plus one limiter
/// keyed by client IP.
///
/// Both limiters sit behind `Arc`, so cloning this struct (as axum's `State`
/// extractor does) is cheap and every clone shares the same counters.
#[derive(Clone)]
pub struct RateLimitState {
    /// Limiter keyed by the client IP string from `client_ip::extract_client_ip`.
    pub ip_limiter: Arc<IpRateLimiter>,
    /// Un-keyed limiter whose quota is shared by all requests that are checked against it.
    pub global_limiter: Arc<GlobalRateLimiter>,
}
|
|
|
|
impl RateLimitState {
|
|
/// Create a new RateLimitState with default limits.
|
|
///
|
|
/// Default limits (can be overridden via environment variables):
|
|
/// - Global: 100 req/s, burst 200
|
|
/// - Per-IP: 20 req/s, burst 40
|
|
pub fn new() -> Self {
|
|
let global_rate = std::env::var("RATE_LIMIT_GLOBAL_PER_SECOND")
|
|
.ok()
|
|
.and_then(|v| v.parse::<u32>().ok())
|
|
.unwrap_or(100);
|
|
|
|
let global_burst = std::env::var("RATE_LIMIT_GLOBAL_BURST")
|
|
.ok()
|
|
.and_then(|v| v.parse::<u32>().ok())
|
|
.unwrap_or(200);
|
|
|
|
let ip_rate = std::env::var("RATE_LIMIT_IP_PER_SECOND")
|
|
.ok()
|
|
.and_then(|v| v.parse::<u32>().ok())
|
|
.unwrap_or(20);
|
|
|
|
let ip_burst = std::env::var("RATE_LIMIT_IP_BURST")
|
|
.ok()
|
|
.and_then(|v| v.parse::<u32>().ok())
|
|
.unwrap_or(40);
|
|
|
|
let global_rate_nz = nz_or_log(global_rate, 100, "RATE_LIMIT_GLOBAL_PER_SECOND");
|
|
let global_burst_nz = nz_or_log(global_burst, 200, "RATE_LIMIT_GLOBAL_BURST");
|
|
let ip_rate_nz = nz_or_log(ip_rate, 20, "RATE_LIMIT_IP_PER_SECOND");
|
|
let ip_burst_nz = nz_or_log(ip_burst, 40, "RATE_LIMIT_IP_BURST");
|
|
|
|
let global_quota = Quota::per_second(global_rate_nz).allow_burst(global_burst_nz);
|
|
let ip_quota = Quota::per_second(ip_rate_nz).allow_burst(ip_burst_nz);
|
|
|
|
tracing::info!(
|
|
global_rate = global_rate_nz.get(),
|
|
global_burst = global_burst_nz.get(),
|
|
ip_rate = ip_rate_nz.get(),
|
|
ip_burst = ip_burst_nz.get(),
|
|
"rate limiter initialized"
|
|
);
|
|
|
|
Self {
|
|
global_limiter: Arc::new(RateLimiter::direct(global_quota)),
|
|
ip_limiter: Arc::new(RateLimiter::dashmap(ip_quota)),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Rate limiting middleware function.
|
|
///
|
|
/// Checks both global and per-IP rate limits before allowing the request through.
|
|
/// Returns 429 Too Many Requests if either limit is exceeded.
|
|
pub async fn rate_limit_middleware(
|
|
State(rl): State<RateLimitState>,
|
|
req: Request,
|
|
next: Next,
|
|
) -> Result<Response, Response> {
|
|
// Check global rate limit first
|
|
if let Err(negative) = rl.global_limiter.check() {
|
|
let retry_after = negative.wait_time_from(DefaultClock::default().now());
|
|
tracing::warn!(
|
|
retry_after_secs = retry_after.as_secs(),
|
|
"global rate limit exceeded"
|
|
);
|
|
return Err(too_many_requests_response(Some(retry_after)));
|
|
}
|
|
|
|
// Check per-IP rate limit
|
|
let key = client_ip::extract_client_ip(&req);
|
|
if let Err(negative) = rl.ip_limiter.check_key(&key) {
|
|
let retry_after = negative.wait_time_from(DefaultClock::default().now());
|
|
tracing::warn!(
|
|
client_ip = %key,
|
|
retry_after_secs = retry_after.as_secs(),
|
|
"per-IP rate limit exceeded"
|
|
);
|
|
return Err(too_many_requests_response(Some(retry_after)));
|
|
}
|
|
|
|
Ok(next.run(req).await)
|
|
}
|
|
|
|
/// Start a background task to clean up expired rate limiter entries.
|
|
///
|
|
/// This should be called once during application startup.
|
|
/// The task runs every 60 seconds and will be aborted on shutdown.
|
|
pub fn spawn_cleanup_task(ip_limiter: Arc<IpRateLimiter>) -> tokio::task::JoinHandle<()> {
|
|
tokio::spawn(async move {
|
|
let mut interval = tokio::time::interval(Duration::from_secs(60));
|
|
loop {
|
|
interval.tick().await;
|
|
ip_limiter.retain_recent();
|
|
}
|
|
})
|
|
}
|
|
|
|
/// Create a 429 Too Many Requests response.
|
|
fn too_many_requests_response(retry_after: Option<Duration>) -> Response {
|
|
let mut headers = HeaderMap::new();
|
|
headers.insert("Content-Type", HeaderValue::from_static("application/json"));
|
|
|
|
if let Some(duration) = retry_after {
|
|
let secs = duration.as_secs().max(1);
|
|
if let Ok(value) = HeaderValue::from_str(&secs.to_string()) {
|
|
headers.insert("Retry-After", value);
|
|
}
|
|
}
|
|
|
|
let body = json!({
|
|
"error": "Too many requests, please try again later"
|
|
});
|
|
|
|
(StatusCode::TOO_MANY_REQUESTS, headers, body.to_string()).into_response()
|
|
}
|