release: secrets-mcp 0.5.2
Bump version: secrets-mcp-0.5.1 tag already existed while crates had further changes. Made-with: Cursor
This commit is contained in:
160
crates/secrets-mcp/src/rate_limit.rs
Normal file
160
crates/secrets-mcp/src/rate_limit.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use std::num::NonZeroU32;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::{HeaderMap, HeaderValue, StatusCode},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use governor::{
|
||||
Quota, RateLimiter,
|
||||
clock::{Clock, DefaultClock},
|
||||
state::{InMemoryState, NotKeyed, keyed::DashMapStateStore},
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::client_ip;
|
||||
|
||||
/// Per-IP rate limiter (keyed by client IP string)
|
||||
type IpRateLimiter = RateLimiter<String, DashMapStateStore<String>, DefaultClock>;
|
||||
|
||||
/// Global rate limiter (not keyed)
|
||||
type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>;
|
||||
|
||||
/// Parse a u32 env value into NonZeroU32, logging a warning and falling back
|
||||
/// to the default if the value is zero.
|
||||
fn nz_or_log(value: u32, default: u32, name: &str) -> NonZeroU32 {
|
||||
NonZeroU32::new(value).unwrap_or_else(|| {
|
||||
tracing::warn!(
|
||||
configured = value,
|
||||
default,
|
||||
"{name} must be non-zero, using default"
|
||||
);
|
||||
NonZeroU32::new(default).unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RateLimitState {
|
||||
pub ip_limiter: Arc<IpRateLimiter>,
|
||||
pub global_limiter: Arc<GlobalRateLimiter>,
|
||||
}
|
||||
|
||||
impl RateLimitState {
|
||||
/// Create a new RateLimitState with default limits.
|
||||
///
|
||||
/// Default limits (can be overridden via environment variables):
|
||||
/// - Global: 100 req/s, burst 200
|
||||
/// - Per-IP: 20 req/s, burst 40
|
||||
pub fn new() -> Self {
|
||||
let global_rate = std::env::var("RATE_LIMIT_GLOBAL_PER_SECOND")
|
||||
.ok()
|
||||
.and_then(|v| v.parse::<u32>().ok())
|
||||
.unwrap_or(100);
|
||||
|
||||
let global_burst = std::env::var("RATE_LIMIT_GLOBAL_BURST")
|
||||
.ok()
|
||||
.and_then(|v| v.parse::<u32>().ok())
|
||||
.unwrap_or(200);
|
||||
|
||||
let ip_rate = std::env::var("RATE_LIMIT_IP_PER_SECOND")
|
||||
.ok()
|
||||
.and_then(|v| v.parse::<u32>().ok())
|
||||
.unwrap_or(20);
|
||||
|
||||
let ip_burst = std::env::var("RATE_LIMIT_IP_BURST")
|
||||
.ok()
|
||||
.and_then(|v| v.parse::<u32>().ok())
|
||||
.unwrap_or(40);
|
||||
|
||||
let global_rate_nz = nz_or_log(global_rate, 100, "RATE_LIMIT_GLOBAL_PER_SECOND");
|
||||
let global_burst_nz = nz_or_log(global_burst, 200, "RATE_LIMIT_GLOBAL_BURST");
|
||||
let ip_rate_nz = nz_or_log(ip_rate, 20, "RATE_LIMIT_IP_PER_SECOND");
|
||||
let ip_burst_nz = nz_or_log(ip_burst, 40, "RATE_LIMIT_IP_BURST");
|
||||
|
||||
let global_quota = Quota::per_second(global_rate_nz).allow_burst(global_burst_nz);
|
||||
let ip_quota = Quota::per_second(ip_rate_nz).allow_burst(ip_burst_nz);
|
||||
|
||||
tracing::info!(
|
||||
global_rate = global_rate_nz.get(),
|
||||
global_burst = global_burst_nz.get(),
|
||||
ip_rate = ip_rate_nz.get(),
|
||||
ip_burst = ip_burst_nz.get(),
|
||||
"rate limiter initialized"
|
||||
);
|
||||
|
||||
Self {
|
||||
global_limiter: Arc::new(RateLimiter::direct(global_quota)),
|
||||
ip_limiter: Arc::new(RateLimiter::dashmap(ip_quota)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Rate limiting middleware function.
|
||||
///
|
||||
/// Checks both global and per-IP rate limits before allowing the request through.
|
||||
/// Returns 429 Too Many Requests if either limit is exceeded.
|
||||
pub async fn rate_limit_middleware(
|
||||
State(rl): State<RateLimitState>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, Response> {
|
||||
// Check global rate limit first
|
||||
if let Err(negative) = rl.global_limiter.check() {
|
||||
let retry_after = negative.wait_time_from(DefaultClock::default().now());
|
||||
tracing::warn!(
|
||||
retry_after_secs = retry_after.as_secs(),
|
||||
"global rate limit exceeded"
|
||||
);
|
||||
return Err(too_many_requests_response(Some(retry_after)));
|
||||
}
|
||||
|
||||
// Check per-IP rate limit
|
||||
let key = client_ip::extract_client_ip(&req);
|
||||
if let Err(negative) = rl.ip_limiter.check_key(&key) {
|
||||
let retry_after = negative.wait_time_from(DefaultClock::default().now());
|
||||
tracing::warn!(
|
||||
client_ip = %key,
|
||||
retry_after_secs = retry_after.as_secs(),
|
||||
"per-IP rate limit exceeded"
|
||||
);
|
||||
return Err(too_many_requests_response(Some(retry_after)));
|
||||
}
|
||||
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
|
||||
/// Start a background task to clean up expired rate limiter entries.
|
||||
///
|
||||
/// This should be called once during application startup.
|
||||
/// The task runs every 60 seconds and will be aborted on shutdown.
|
||||
pub fn spawn_cleanup_task(ip_limiter: Arc<IpRateLimiter>) -> tokio::task::JoinHandle<()> {
|
||||
tokio::spawn(async move {
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(60));
|
||||
loop {
|
||||
interval.tick().await;
|
||||
ip_limiter.retain_recent();
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Create a 429 Too Many Requests response.
|
||||
fn too_many_requests_response(retry_after: Option<Duration>) -> Response {
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("Content-Type", HeaderValue::from_static("application/json"));
|
||||
|
||||
if let Some(duration) = retry_after {
|
||||
let secs = duration.as_secs().max(1);
|
||||
if let Ok(value) = HeaderValue::from_str(&secs.to_string()) {
|
||||
headers.insert("Retry-After", value);
|
||||
}
|
||||
}
|
||||
|
||||
let body = json!({
|
||||
"error": "Too many requests, please try again later"
|
||||
});
|
||||
|
||||
(StatusCode::TOO_MANY_REQUESTS, headers, body.to_string()).into_response()
|
||||
}
|
||||
Reference in New Issue
Block a user