diff --git a/Cargo.lock b/Cargo.lock index 725461e..0f305f1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -464,6 +464,20 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "der" version = "0.7.10" @@ -596,6 +610,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -687,6 +707,12 @@ version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" +[[package]] +name = "futures-timer" +version = "3.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f288b0a4f20f9a56b5d1da57e2227c661b7b16168e2f72365f57b63326e29b24" + [[package]] name = "futures-util" version = "0.3.32" @@ -765,6 +791,35 @@ dependencies = [ "polyval", ] +[[package]] +name = "governor" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9efcab3c1958580ff1f25a2a41be1668f7603d849bb63af523b208a3cc1223b8" +dependencies = [ + "cfg-if", + "dashmap", + "futures-sink", + "futures-timer", + "futures-util", + "getrandom 0.3.4", + "hashbrown 0.16.1", + "nonzero_ext", + "parking_lot", + "portable-atomic", + "quanta", + "rand 0.9.2", + "smallvec", + "spinning_top", + "web-time", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" + [[package]] name = "hashbrown" version = "0.15.5" @@ -773,7 +828,7 @@ checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" dependencies = [ "allocator-api2", "equivalent", - "foldhash", + "foldhash 0.1.5", ] [[package]] @@ -781,6 +836,11 @@ name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash 0.2.0", +] [[package]] name = "hashlink" @@ -1283,6 +1343,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nonzero_ext" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38bf9645c8b145698bb0b18a4637dcacbc421ea49bef2317e4fd8065a387cf21" + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -1463,6 +1529,12 @@ dependencies = [ "universal-hash", ] +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + [[package]] name = "potential_utf" version = "0.1.4" @@ -1506,6 +1578,21 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "quanta" +version = "0.12.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3ab5a9d756f0d97bdc89019bd2e4ea098cf9cde50ee7564dde6b81ccc8f06c7" +dependencies = [ + "crossbeam-utils", + "libc", + "once_cell", + "raw-cpuid", + "wasi", + "web-sys", + "winapi", +] + [[package]] name = "quinn" version = "0.11.9" @@ -1658,6 +1745,15 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" +[[package]] +name = "raw-cpuid" +version = "11.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" +dependencies = [ + "bitflags", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -1969,7 +2065,7 @@ dependencies = [ [[package]] name = "secrets-mcp" -version = "0.5.1" +version = "0.5.2" dependencies = [ "anyhow", "askama", @@ -1977,6 +2073,7 @@ dependencies = [ "axum-extra", "chrono", "dotenvy", + "governor", "http", "rand 0.10.0", "reqwest", @@ -1995,6 +2092,7 @@ dependencies = [ "tower-sessions-sqlx-store-chrono", "tracing", "tracing-subscriber", + "url", "urlencoding", "uuid", ] @@ -2195,6 +2293,15 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spinning_top" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d96d2d1d716fb500937168cc09353ffdc7a012be8475ac7308e1bdf0e3923300" +dependencies = [ + "lock_api", +] + [[package]] name = "spki" version = "0.7.3" @@ -2717,6 +2824,7 @@ dependencies = [ "futures-util", "http", "http-body", + "http-body-util", "iri-string", "pin-project-lite", "tower", @@ -3167,6 +3275,28 @@ dependencies = [ "wasite", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.62.2" diff --git a/crates/secrets-core/src/db.rs b/crates/secrets-core/src/db.rs index 72304c3..8d92947 100644 --- a/crates/secrets-core/src/db.rs +++ b/crates/secrets-core/src/db.rs @@ -36,12 +36,31 @@ fn build_connect_options(config: &DatabaseConfig) -> Result { pub async fn create_pool(config: &DatabaseConfig) -> Result { tracing::debug!("connecting to database"); let connect_options = build_connect_options(config)?; + + // Connection pool configuration from environment + let max_connections = std::env::var("SECRETS_DATABASE_POOL_SIZE") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(10); + + let acquire_timeout_secs = std::env::var("SECRETS_DATABASE_ACQUIRE_TIMEOUT") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(5); + let pool = PgPoolOptions::new() - .max_connections(10) - .acquire_timeout(std::time::Duration::from_secs(5)) + .max_connections(max_connections) + .acquire_timeout(std::time::Duration::from_secs(acquire_timeout_secs)) + .max_lifetime(std::time::Duration::from_secs(1800)) // 30 minutes + .idle_timeout(std::time::Duration::from_secs(600)) // 10 minutes .connect_with(connect_options) .await?; - tracing::debug!("database connection established"); + + tracing::debug!( + max_connections, + acquire_timeout_secs, + "database connection established" + ); Ok(pool) } diff --git a/crates/secrets-core/src/error.rs b/crates/secrets-core/src/error.rs index 2c3d604..699e52a 100644 --- a/crates/secrets-core/src/error.rs +++ b/crates/secrets-core/src/error.rs @@ -15,6 +15,18 @@ pub enum AppError { #[error("Entry not found")] NotFoundEntry, + #[error("User not found")] + NotFoundUser, + + #[error("Secret not found")] + NotFoundSecret, + + #[error("Authentication failed")] + AuthenticationFailed, + + #[error("Unauthorized: insufficient permissions")] + Unauthorized, + #[error("Validation failed: {message}")] Validation { message: String }, @@ -24,6 +36,9 @@ pub enum AppError { #[error("Decryption failed — the encryption key may be incorrect")] DecryptionFailed, + #[error("Encryption key not set — user must set passphrase first")] + EncryptionKeyNotSet, + #[error(transparent)] Internal(#[from] anyhow::Error), } @@ -119,6 +134,18 @@ mod tests { let err = AppError::NotFoundEntry; assert_eq!(err.to_string(), "Entry not found"); + let err = AppError::NotFoundUser; + assert_eq!(err.to_string(), "User not found"); + + let err = AppError::NotFoundSecret; + assert_eq!(err.to_string(), "Secret not found"); + + let err = AppError::AuthenticationFailed; + assert_eq!(err.to_string(), "Authentication failed"); + + let err = AppError::Unauthorized; + assert!(err.to_string().contains("Unauthorized")); + let err = AppError::Validation { message: "too long".to_string(), }; @@ -126,6 +153,9 @@ mod tests { let err = AppError::ConcurrentModification; assert!(err.to_string().contains("Concurrent modification")); + + let err = AppError::EncryptionKeyNotSet; + assert!(err.to_string().contains("Encryption key not set")); } #[test] diff --git a/crates/secrets-core/src/service/api_key.rs b/crates/secrets-core/src/service/api_key.rs index 35e887f..8afaac4 100644 --- a/crates/secrets-core/src/service/api_key.rs +++ b/crates/secrets-core/src/service/api_key.rs @@ -2,6 +2,8 @@ use anyhow::Result; use sqlx::PgPool; use uuid::Uuid; +use crate::error::AppError; + const KEY_PREFIX: &str = "sk_"; /// Generate a new API key: `sk_<64 hex chars>` = 67 characters total. @@ -14,23 +16,32 @@ pub fn generate_api_key() -> String { } /// Return the user's existing API key, or generate and store a new one if NULL. +/// Uses a transaction with atomic update to prevent TOCTOU race conditions. pub async fn ensure_api_key(pool: &PgPool, user_id: Uuid) -> Result { - let existing: Option<(Option,)> = - sqlx::query_as("SELECT api_key FROM users WHERE id = $1") - .bind(user_id) - .fetch_optional(pool) - .await?; + let mut tx = pool.begin().await?; - if let Some((Some(key),)) = existing { + // Lock the row and check existing key + let existing: (Option,) = + sqlx::query_as("SELECT api_key FROM users WHERE id = $1 FOR UPDATE") + .bind(user_id) + .fetch_optional(&mut *tx) + .await? + .ok_or(AppError::NotFoundUser)?; + + if let Some(key) = existing.0 { + tx.commit().await?; return Ok(key); } + // Generate and store new key atomically let new_key = generate_api_key(); sqlx::query("UPDATE users SET api_key = $1 WHERE id = $2") .bind(&new_key) .bind(user_id) - .execute(pool) + .execute(&mut *tx) .await?; + + tx.commit().await?; Ok(new_key) } diff --git a/crates/secrets-core/src/service/user.rs b/crates/secrets-core/src/service/user.rs index 40dd443..b932a50 100644 --- a/crates/secrets-core/src/service/user.rs +++ b/crates/secrets-core/src/service/user.rs @@ -16,14 +16,17 @@ pub struct OAuthProfile { /// Find or create a user from an OAuth profile. /// Returns (user, is_new) where is_new indicates first-time registration. pub async fn find_or_create_user(pool: &PgPool, profile: OAuthProfile) -> Result<(User, bool)> { - // Check if this OAuth account already exists + // Use a transaction with FOR UPDATE to prevent TOCTOU race conditions + let mut tx = pool.begin().await?; + + // Check if this OAuth account already exists (with row lock) let existing: Option = sqlx::query_as( "SELECT id, user_id, provider, provider_id, email, name, avatar_url, created_at \ - FROM oauth_accounts WHERE provider = $1 AND provider_id = $2", + FROM oauth_accounts WHERE provider = $1 AND provider_id = $2 FOR UPDATE", ) .bind(&profile.provider) .bind(&profile.provider_id) - .fetch_optional(pool) + .fetch_optional(&mut *tx) .await?; if let Some(oa) = existing { @@ -32,8 +35,9 @@ pub async fn find_or_create_user(pool: &PgPool, profile: OAuthProfile) -> Result FROM users WHERE id = $1", ) .bind(oa.user_id) - .fetch_one(pool) + .fetch_one(&mut *tx) .await?; + tx.commit().await?; return Ok((user, false)); } @@ -43,8 +47,6 @@ pub async fn find_or_create_user(pool: &PgPool, profile: OAuthProfile) -> Result .clone() .unwrap_or_else(|| profile.email.clone().unwrap_or_else(|| "User".to_string())); - let mut tx = pool.begin().await?; - let user: User = sqlx::query_as( "INSERT INTO users (email, name, avatar_url) \ VALUES ($1, $2, $3) \ @@ -125,13 +127,16 @@ pub async fn bind_oauth_account( user_id: Uuid, profile: OAuthProfile, ) -> Result { - // Check if this provider_id is already linked to someone else + // Use a transaction with FOR UPDATE to prevent TOCTOU race conditions + let mut tx = pool.begin().await?; + + // Check if this provider_id is already linked to someone else (with row lock) let conflict: Option<(Uuid,)> = sqlx::query_as( - "SELECT user_id FROM oauth_accounts WHERE provider = $1 AND provider_id = $2", + "SELECT user_id FROM oauth_accounts WHERE provider = $1 AND provider_id = $2 FOR UPDATE", ) .bind(&profile.provider) .bind(&profile.provider_id) - .fetch_optional(pool) + .fetch_optional(&mut *tx) .await?; if let Some((existing_user_id,)) = conflict { @@ -148,11 +153,11 @@ pub async fn bind_oauth_account( } let existing_provider_for_user: Option<(String,)> = sqlx::query_as( - "SELECT provider_id FROM oauth_accounts WHERE user_id = $1 AND provider = $2", + "SELECT provider_id FROM oauth_accounts WHERE user_id = $1 AND provider = $2 FOR UPDATE", ) .bind(user_id) .bind(&profile.provider) - .fetch_optional(pool) + .fetch_optional(&mut *tx) .await?; if existing_provider_for_user.is_some() { @@ -174,9 +179,10 @@ pub async fn bind_oauth_account( .bind(&profile.email) .bind(&profile.name) .bind(&profile.avatar_url) - .fetch_one(pool) + .fetch_one(&mut *tx) .await?; + tx.commit().await?; Ok(account) } diff --git a/crates/secrets-mcp/Cargo.toml b/crates/secrets-mcp/Cargo.toml index 4e471ec..323f610 100644 --- a/crates/secrets-mcp/Cargo.toml +++ b/crates/secrets-mcp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "secrets-mcp" -version = "0.5.1" +version = "0.5.2" edition.workspace = true [[bin]] @@ -17,9 +17,10 @@ rmcp = { version = "1", features = ["server", "macros", "transport-streamable-ht axum = "0.8" axum-extra = { version = "0.10", features = ["typed-header"] } tower = "0.5" -tower-http = { version = "0.6", features = ["cors", "trace"] } +tower-http = { version = "0.6", features = ["cors", "trace", "limit"] } tower-sessions = "0.14" tower-sessions-sqlx-store-chrono = { version = "0.14", features = ["postgres"] } +governor = { version = "0.10", features = ["std", "jitter"] } time = "0.3" # OAuth (manual token exchange via reqwest) @@ -44,3 +45,4 @@ dotenvy.workspace = true urlencoding = "2" schemars = "1" http = "1" +url = "2" diff --git a/crates/secrets-mcp/src/auth.rs b/crates/secrets-mcp/src/auth.rs index 304f05e..22fa3ec 100644 --- a/crates/secrets-mcp/src/auth.rs +++ b/crates/secrets-mcp/src/auth.rs @@ -1,7 +1,5 @@ -use std::net::SocketAddr; - use axum::{ - extract::{ConnectInfo, Request, State}, + extract::{Request, State}, http::StatusCode, middleware::Next, response::Response, @@ -11,29 +9,14 @@ use uuid::Uuid; use secrets_core::service::api_key::validate_api_key; +use crate::client_ip; + /// Injected into request extensions after Bearer token validation. #[derive(Clone, Debug)] pub struct AuthUser { pub user_id: Uuid, } -fn log_client_ip(req: &Request) -> Option { - if let Some(first) = req - .headers() - .get("x-forwarded-for") - .and_then(|v| v.to_str().ok()) - .and_then(|s| s.split(',').next()) - { - let s = first.trim(); - if !s.is_empty() { - return Some(s.to_string()); - } - } - req.extensions() - .get::>() - .map(|c| c.ip().to_string()) -} - /// Axum middleware that validates Bearer API keys for the /mcp route. /// Passes all non-MCP paths through without authentication. pub async fn bearer_auth_middleware( @@ -43,7 +26,7 @@ pub async fn bearer_auth_middleware( ) -> Result { let path = req.uri().path(); let method = req.method().as_str(); - let client_ip = log_client_ip(&req); + let client_ip = client_ip::extract_client_ip(&req); // Only authenticate /mcp paths if !path.starts_with("/mcp") { @@ -66,7 +49,7 @@ pub async fn bearer_auth_middleware( tracing::warn!( method, path, - client_ip = client_ip.as_deref(), + %client_ip, "invalid Authorization header format on /mcp (expected Bearer …)" ); return Err(StatusCode::UNAUTHORIZED); @@ -75,7 +58,7 @@ pub async fn bearer_auth_middleware( tracing::warn!( method, path, - client_ip = client_ip.as_deref(), + %client_ip, "missing Authorization header on /mcp" ); return Err(StatusCode::UNAUTHORIZED); @@ -93,7 +76,7 @@ pub async fn bearer_auth_middleware( tracing::warn!( method, path, - client_ip = client_ip.as_deref(), + %client_ip, key_prefix = %&raw_key.chars().take(12).collect::(), key_len = raw_key.len(), "invalid api key (not found in database — e.g. revoked key or DB was reset; update MCP client Bearer token)" @@ -104,7 +87,7 @@ pub async fn bearer_auth_middleware( tracing::error!( method, path, - client_ip = client_ip.as_deref(), + %client_ip, error = %e, "api key validation error" ); diff --git a/crates/secrets-mcp/src/client_ip.rs b/crates/secrets-mcp/src/client_ip.rs new file mode 100644 index 0000000..ae317cb --- /dev/null +++ b/crates/secrets-mcp/src/client_ip.rs @@ -0,0 +1,65 @@ +use axum::extract::Request; +use std::net::{IpAddr, SocketAddr}; + +/// Extract the client IP from a request. +/// +/// When the `TRUST_PROXY` environment variable is set to `1` or `true`, the +/// `X-Forwarded-For` and `X-Real-IP` headers are consulted first, which is +/// appropriate when the service runs behind a trusted reverse proxy (e.g. +/// Caddy). Otherwise — or if those headers are absent/empty — the direct TCP +/// connection address from `ConnectInfo` is used. +/// +/// **Important**: only enable `TRUST_PROXY` when the application is guaranteed +/// to receive traffic exclusively through a controlled reverse proxy. Enabling +/// it on a directly-exposed port allows clients to spoof their IP address and +/// bypass per-IP rate limiting. +pub fn extract_client_ip(req: &Request) -> String { + if trust_proxy_enabled() { + if let Some(ip) = forwarded_for_ip(req.headers()) { + return ip; + } + if let Some(ip) = real_ip(req.headers()) { + return ip; + } + } + + connect_info_ip(req).unwrap_or_else(|| "unknown".to_string()) +} + +fn trust_proxy_enabled() -> bool { + static CACHE: std::sync::OnceLock = std::sync::OnceLock::new(); + *CACHE.get_or_init(|| { + matches!( + std::env::var("TRUST_PROXY").as_deref(), + Ok("1") | Ok("true") | Ok("yes") + ) + }) +} + +fn forwarded_for_ip(headers: &axum::http::HeaderMap) -> Option { + let value = headers.get("x-forwarded-for")?.to_str().ok()?; + let first = value.split(',').next()?.trim(); + if first.is_empty() { + None + } else { + validate_ip(first) + } +} + +fn real_ip(headers: &axum::http::HeaderMap) -> Option { + let value = headers.get("x-real-ip")?.to_str().ok()?; + let ip = value.trim(); + if ip.is_empty() { None } else { validate_ip(ip) } +} + +/// Validate that a string is a valid IP address. +/// Returns Some(ip) if valid, None otherwise. +fn validate_ip(s: &str) -> Option { + s.parse::().ok().map(|ip| ip.to_string()) +} + +fn connect_info_ip(req: &Request) -> Option { + req.extensions() + .get::>() + .map(|c| c.0.ip().to_string()) +} diff --git a/crates/secrets-mcp/src/error.rs b/crates/secrets-mcp/src/error.rs index 09adda3..2311d11 100644 --- a/crates/secrets-mcp/src/error.rs +++ b/crates/secrets-mcp/src/error.rs @@ -23,6 +23,16 @@ pub fn app_error_to_mcp(err: &AppError) -> rmcp::ErrorData { "Entry not found. Use secrets_find to discover existing entries.", None, ), + AppError::NotFoundUser => rmcp::ErrorData::invalid_request("User not found.", None), + AppError::NotFoundSecret => rmcp::ErrorData::invalid_request("Secret not found.", None), + AppError::AuthenticationFailed => rmcp::ErrorData::invalid_request( + "Authentication failed. Please check your API key or login credentials.", + None, + ), + AppError::Unauthorized => rmcp::ErrorData::invalid_request( + "Unauthorized: you do not have permission to access this resource.", + None, + ), AppError::Validation { message } => rmcp::ErrorData::invalid_request(message.clone(), None), AppError::ConcurrentModification => rmcp::ErrorData::invalid_request( "The entry was modified by another request. Please refresh and try again.", @@ -32,6 +42,10 @@ pub fn app_error_to_mcp(err: &AppError) -> rmcp::ErrorData { "Decryption failed — the encryption key may be incorrect or does not match the data.", None, ), + AppError::EncryptionKeyNotSet => rmcp::ErrorData::invalid_request( + "Encryption key not set. You must set a passphrase before using this feature.", + None, + ), AppError::Internal(_) => rmcp::ErrorData::internal_error( "Request failed due to a server error. Check service logs if you need details.", None, diff --git a/crates/secrets-mcp/src/main.rs b/crates/secrets-mcp/src/main.rs index f3452e0..f924882 100644 --- a/crates/secrets-mcp/src/main.rs +++ b/crates/secrets-mcp/src/main.rs @@ -1,8 +1,11 @@ mod auth; +mod client_ip; mod error; mod logging; mod oauth; +mod rate_limit; mod tools; +mod validation; mod web; use std::net::SocketAddr; @@ -153,10 +156,43 @@ async fn main() -> Result<()> { ); // ── Router ──────────────────────────────────────────────────────────────── - let cors = CorsLayer::new() - .allow_origin(Any) - .allow_methods(Any) - .allow_headers(Any); + // CORS: restrict origins in production, allow all in development + let is_production = matches!( + load_env_var("SECRETS_ENV") + .as_deref() + .map(|s| s.to_ascii_lowercase()) + .as_deref(), + Some("prod" | "production") + ); + + let cors = if is_production { + // Only use the origin part (scheme://host:port) of BASE_URL for CORS. + // Browsers send Origin without path, so including a path would cause mismatches. + let allowed_origin = if let Ok(parsed) = base_url.parse::() { + let origin = parsed.origin().ascii_serialization(); + origin + .parse::() + .unwrap_or_else(|_| panic!("invalid BASE_URL origin: {}", origin)) + } else { + base_url + .parse::() + .unwrap_or_else(|_| panic!("invalid BASE_URL: {}", base_url)) + }; + CorsLayer::new() + .allow_origin(allowed_origin) + .allow_methods(Any) + .allow_headers(Any) + .allow_credentials(true) + } else { + CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any) + }; + + // Rate limiting + let rate_limit_state = rate_limit::RateLimitState::new(); + let rate_limit_cleanup = rate_limit::spawn_cleanup_task(rate_limit_state.ip_limiter.clone()); let router = Router::new() .merge(web::web_router()) @@ -168,6 +204,10 @@ async fn main() -> Result<()> { pool, auth::bearer_auth_middleware, )) + .layer(axum::middleware::from_fn_with_state( + rate_limit_state.clone(), + rate_limit::rate_limit_middleware, + )) .layer(session_layer) .layer(cors) .with_state(app_state); @@ -192,12 +232,28 @@ async fn main() -> Result<()> { .context("server error")?; session_cleanup.abort(); + rate_limit_cleanup.abort(); Ok(()) } async fn shutdown_signal() { - tokio::signal::ctrl_c() - .await - .expect("failed to install CTRL+C signal handler"); + let ctrl_c = tokio::signal::ctrl_c(); + + #[cfg(unix)] + let terminate = async { + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("failed to install SIGTERM handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } + tracing::info!("Shutting down gracefully..."); } diff --git a/crates/secrets-mcp/src/rate_limit.rs b/crates/secrets-mcp/src/rate_limit.rs new file mode 100644 index 0000000..0e00312 --- /dev/null +++ b/crates/secrets-mcp/src/rate_limit.rs @@ -0,0 +1,160 @@ +use std::num::NonZeroU32; +use std::sync::Arc; +use std::time::Duration; + +use axum::{ + extract::{Request, State}, + http::{HeaderMap, HeaderValue, StatusCode}, + middleware::Next, + response::{IntoResponse, Response}, +}; +use governor::{ + Quota, RateLimiter, + clock::{Clock, DefaultClock}, + state::{InMemoryState, NotKeyed, keyed::DashMapStateStore}, +}; +use serde_json::json; + +use crate::client_ip; + +/// Per-IP rate limiter (keyed by client IP string) +type IpRateLimiter = RateLimiter, DefaultClock>; + +/// Global rate limiter (not keyed) +type GlobalRateLimiter = RateLimiter; + +/// Parse a u32 env value into NonZeroU32, logging a warning and falling back +/// to the default if the value is zero. +fn nz_or_log(value: u32, default: u32, name: &str) -> NonZeroU32 { + NonZeroU32::new(value).unwrap_or_else(|| { + tracing::warn!( + configured = value, + default, + "{name} must be non-zero, using default" + ); + NonZeroU32::new(default).unwrap() + }) +} + +#[derive(Clone)] +pub struct RateLimitState { + pub ip_limiter: Arc, + pub global_limiter: Arc, +} + +impl RateLimitState { + /// Create a new RateLimitState with default limits. + /// + /// Default limits (can be overridden via environment variables): + /// - Global: 100 req/s, burst 200 + /// - Per-IP: 20 req/s, burst 40 + pub fn new() -> Self { + let global_rate = std::env::var("RATE_LIMIT_GLOBAL_PER_SECOND") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(100); + + let global_burst = std::env::var("RATE_LIMIT_GLOBAL_BURST") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(200); + + let ip_rate = std::env::var("RATE_LIMIT_IP_PER_SECOND") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(20); + + let ip_burst = std::env::var("RATE_LIMIT_IP_BURST") + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(40); + + let global_rate_nz = nz_or_log(global_rate, 100, "RATE_LIMIT_GLOBAL_PER_SECOND"); + let global_burst_nz = nz_or_log(global_burst, 200, "RATE_LIMIT_GLOBAL_BURST"); + let ip_rate_nz = nz_or_log(ip_rate, 20, "RATE_LIMIT_IP_PER_SECOND"); + let ip_burst_nz = nz_or_log(ip_burst, 40, "RATE_LIMIT_IP_BURST"); + + let global_quota = Quota::per_second(global_rate_nz).allow_burst(global_burst_nz); + let ip_quota = Quota::per_second(ip_rate_nz).allow_burst(ip_burst_nz); + + tracing::info!( + global_rate = global_rate_nz.get(), + global_burst = global_burst_nz.get(), + ip_rate = ip_rate_nz.get(), + ip_burst = ip_burst_nz.get(), + "rate limiter initialized" + ); + + Self { + global_limiter: Arc::new(RateLimiter::direct(global_quota)), + ip_limiter: Arc::new(RateLimiter::dashmap(ip_quota)), + } + } +} + +/// Rate limiting middleware function. +/// +/// Checks both global and per-IP rate limits before allowing the request through. +/// Returns 429 Too Many Requests if either limit is exceeded. +pub async fn rate_limit_middleware( + State(rl): State, + req: Request, + next: Next, +) -> Result { + // Check global rate limit first + if let Err(negative) = rl.global_limiter.check() { + let retry_after = negative.wait_time_from(DefaultClock::default().now()); + tracing::warn!( + retry_after_secs = retry_after.as_secs(), + "global rate limit exceeded" + ); + return Err(too_many_requests_response(Some(retry_after))); + } + + // Check per-IP rate limit + let key = client_ip::extract_client_ip(&req); + if let Err(negative) = rl.ip_limiter.check_key(&key) { + let retry_after = negative.wait_time_from(DefaultClock::default().now()); + tracing::warn!( + client_ip = %key, + retry_after_secs = retry_after.as_secs(), + "per-IP rate limit exceeded" + ); + return Err(too_many_requests_response(Some(retry_after))); + } + + Ok(next.run(req).await) +} + +/// Start a background task to clean up expired rate limiter entries. +/// +/// This should be called once during application startup. +/// The task runs every 60 seconds and will be aborted on shutdown. +pub fn spawn_cleanup_task(ip_limiter: Arc) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let mut interval = tokio::time::interval(Duration::from_secs(60)); + loop { + interval.tick().await; + ip_limiter.retain_recent(); + } + }) +} + +/// Create a 429 Too Many Requests response. +fn too_many_requests_response(retry_after: Option) -> Response { + let mut headers = HeaderMap::new(); + headers.insert("Content-Type", HeaderValue::from_static("application/json")); + + if let Some(duration) = retry_after { + let secs = duration.as_secs().max(1); + if let Ok(value) = HeaderValue::from_str(&secs.to_string()) { + headers.insert("Retry-After", value); + } + } + + let body = json!({ + "error": "Too many requests, please try again later" + }); + + (StatusCode::TOO_MANY_REQUESTS, headers, body.to_string()).into_response() +} diff --git a/crates/secrets-mcp/src/tools.rs b/crates/secrets-mcp/src/tools.rs index a0482fb..28e8525 100644 --- a/crates/secrets-mcp/src/tools.rs +++ b/crates/secrets-mcp/src/tools.rs @@ -18,6 +18,8 @@ use serde_json::{Map, Value}; use sqlx::PgPool; use uuid::Uuid; +use crate::validation; + // ── Serde helpers for numeric parameters that may arrive as strings ────────── mod deser { @@ -593,6 +595,44 @@ fn map_to_kv_strings(map: Map) -> Vec { .collect() } +/// Check if any KV string would trigger a server-side file read. +/// +/// `parse_kv` in secrets-core supports two file-read syntaxes: +/// - `key=@path` (has `=`, value starts with `@`) +/// - `key@path` (no `=`, split on `@`) +/// +/// Both are legitimate for CLI usage but must be rejected in the MCP context +/// where the server process runs remotely and the caller controls the path. +/// +/// Note: `key:=json` is intentionally skipped here. Although the value may +/// contain `@` characters (e.g. `config:=@/etc/passwd`), the `:=` branch in +/// `parse_kv` treats the right-hand side as raw JSON and never performs file +/// reads. The `@` in such cases is just data, not a file reference. +fn contains_file_reference(entries: &[String]) -> Option { + for entry in entries { + // key:=json — safe, skip before checking for `=` + if entry.contains(":=") { + continue; + } + // key=@path + if let Some((_, value)) = entry.split_once('=') { + if value.starts_with('@') { + return Some(entry.clone()); + } + continue; + } + // key@path (no `=` present) + // parse_kv treats entries without `=` that contain `@` as file-read + // syntax (key@path). This includes strings like "user@example.com" + // if passed without a `=` separator — which is correct to reject here + // since the MCP server runs remotely and cannot read local files. + if entry.contains('@') { + return Some(entry.clone()); + } + } + None +} + /// Parse a UUID string, returning an MCP error on failure. fn parse_uuid(s: &str) -> Result { s.parse::() @@ -879,10 +919,33 @@ impl SecretsService { if let Some(obj) = input.meta_obj { meta.extend(map_to_kv_strings(obj)); } + if let Some(offending) = contains_file_reference(&meta) { + return Err(rmcp::ErrorData::invalid_params( + format!("@file syntax is not allowed in MCP tools: '{}'", offending), + None, + )); + } let mut secrets = input.secrets.unwrap_or_default(); if let Some(obj) = input.secrets_obj { secrets.extend(map_to_kv_strings(obj)); } + if let Some(offending) = contains_file_reference(&secrets) { + return Err(rmcp::ErrorData::invalid_params( + format!("@file syntax is not allowed in MCP tools: '{}'", offending), + None, + )); + } + + // Input length validation + validation::validate_input_lengths( + &input.name, + input.folder.as_deref(), + input.entry_type.as_deref(), + input.notes.as_deref(), + )?; + validation::validate_tags(&tags)?; + validation::validate_meta_entries(&meta)?; + let secret_types = input.secret_types.unwrap_or_default(); let secret_types_map: std::collections::HashMap = secret_types .into_iter() @@ -962,11 +1025,34 @@ impl SecretsService { if let Some(obj) = input.meta_obj { meta.extend(map_to_kv_strings(obj)); } + if let Some(offending) = contains_file_reference(&meta) { + return Err(rmcp::ErrorData::invalid_params( + format!("@file syntax is not allowed in MCP tools: '{}'", offending), + None, + )); + } let remove_meta = input.remove_meta.unwrap_or_default(); let mut secrets = input.secrets.unwrap_or_default(); if let Some(obj) = input.secrets_obj { secrets.extend(map_to_kv_strings(obj)); } + if let Some(offending) = contains_file_reference(&secrets) { + return Err(rmcp::ErrorData::invalid_params( + format!("@file syntax is not allowed in MCP tools: '{}'", offending), + None, + )); + } + + // Input length validation + validation::validate_input_lengths( + &input.name, + input.folder.as_deref(), + None, + input.notes.as_deref(), + )?; + validation::validate_tags(&add_tags)?; + validation::validate_meta_entries(&meta)?; + let secret_types = input.secret_types.unwrap_or_default(); let secret_types_map: std::collections::HashMap = secret_types .into_iter() diff --git a/crates/secrets-mcp/src/validation.rs b/crates/secrets-mcp/src/validation.rs new file mode 100644 index 0000000..02a0e30 --- /dev/null +++ b/crates/secrets-mcp/src/validation.rs @@ -0,0 +1,149 @@ +/// Validation constants for input field lengths. +pub const MAX_NAME_LENGTH: usize = 256; +pub const MAX_FOLDER_LENGTH: usize = 128; +pub const MAX_ENTRY_TYPE_LENGTH: usize = 64; +pub const MAX_NOTES_LENGTH: usize = 10000; +pub const MAX_TAG_LENGTH: usize = 64; +pub const MAX_TAG_COUNT: usize = 50; +pub const MAX_META_KEY_LENGTH: usize = 128; +pub const MAX_META_VALUE_LENGTH: usize = 4096; +pub const MAX_META_COUNT: usize = 100; + +/// Validate input field lengths for MCP tools. +/// +/// Returns an error if any field exceeds its maximum length. +pub fn validate_input_lengths( + name: &str, + folder: Option<&str>, + entry_type: Option<&str>, + notes: Option<&str>, +) -> Result<(), rmcp::ErrorData> { + if name.chars().count() > MAX_NAME_LENGTH { + return Err(rmcp::ErrorData::invalid_params( + format!("name must be at most {} characters", MAX_NAME_LENGTH), + None, + )); + } + if let Some(folder) = folder + && folder.chars().count() > MAX_FOLDER_LENGTH + { + return Err(rmcp::ErrorData::invalid_params( + format!("folder must be at most {} characters", MAX_FOLDER_LENGTH), + None, + )); + } + if let Some(entry_type) = entry_type + && entry_type.chars().count() > MAX_ENTRY_TYPE_LENGTH + { + return Err(rmcp::ErrorData::invalid_params( + format!("type must be at most {} characters", MAX_ENTRY_TYPE_LENGTH), + None, + )); + } + if let Some(notes) = notes + && notes.chars().count() > MAX_NOTES_LENGTH + { + return Err(rmcp::ErrorData::invalid_params( + format!("notes must be at most {} characters", MAX_NOTES_LENGTH), + None, + )); + } + Ok(()) +} + +/// Validate the tags list. +/// +/// Checks total count and per-tag character length. +pub fn validate_tags(tags: &[String]) -> Result<(), rmcp::ErrorData> { + if tags.len() > MAX_TAG_COUNT { + return Err(rmcp::ErrorData::invalid_params( + format!("at most {} tags are allowed", MAX_TAG_COUNT), + None, + )); + } + for tag in tags { + if tag.chars().count() > MAX_TAG_LENGTH { + return Err(rmcp::ErrorData::invalid_params( + format!( + "tag '{}' exceeds the maximum length of {} characters", + tag, MAX_TAG_LENGTH + ), + None, + )); + } + } + Ok(()) +} + +/// Validate metadata KV strings (key=value / key:=json format). +/// +/// Checks total count and per-key/per-value character lengths. +/// This is a best-effort check on the raw KV strings before parsing; +/// keys containing `:` path separators are checked as a whole. +pub fn validate_meta_entries(entries: &[String]) -> Result<(), rmcp::ErrorData> { + if entries.len() > MAX_META_COUNT { + return Err(rmcp::ErrorData::invalid_params( + format!("at most {} metadata entries are allowed", MAX_META_COUNT), + None, + )); + } + for entry in entries { + // key:=json — check both key and JSON value length + if let Some((key, value)) = entry.split_once(":=") { + if key.chars().count() > MAX_META_KEY_LENGTH { + return Err(rmcp::ErrorData::invalid_params( + format!( + "metadata key '{}' exceeds the maximum length of {} characters", + key, MAX_META_KEY_LENGTH + ), + None, + )); + } + if value.chars().count() > MAX_META_VALUE_LENGTH { + return Err(rmcp::ErrorData::invalid_params( + format!( + "metadata JSON value for key '{}' exceeds the maximum length of {} characters", + key, MAX_META_VALUE_LENGTH + ), + None, + )); + } + continue; + } + // key=value or key@path + if let Some((key, value)) = entry.split_once('=') { + if key.chars().count() > MAX_META_KEY_LENGTH { + return Err(rmcp::ErrorData::invalid_params( + format!( + "metadata key '{}' exceeds the maximum length of {} characters", + key, MAX_META_KEY_LENGTH + ), + None, + )); + } + if value.chars().count() > MAX_META_VALUE_LENGTH { + return Err(rmcp::ErrorData::invalid_params( + format!( + "metadata value for key '{}' exceeds the maximum length of {} characters", + key, MAX_META_VALUE_LENGTH + ), + None, + )); + } + } else { + // Fallback: entry without = or := — check total length + let max_total = MAX_META_KEY_LENGTH + MAX_META_VALUE_LENGTH; + if entry.chars().count() > max_total { + let preview = entry.chars().take(50).collect::(); + return Err(rmcp::ErrorData::invalid_params( + format!( + "metadata entry '{}' exceeds the maximum length of {} characters", + preview, max_total + ), + None, + )); + } + } + } + Ok(()) +} diff --git a/crates/secrets-mcp/src/web.rs b/crates/secrets-mcp/src/web.rs index 1249917..78cb49e 100644 --- a/crates/secrets-mcp/src/web.rs +++ b/crates/secrets-mcp/src/web.rs @@ -1134,10 +1134,16 @@ fn map_app_error(err: &AppError, lang: UiLang) -> EntryApiError { StatusCode::CONFLICT, Json(json!({ "error": err.to_string() })), ), - AppError::NotFoundEntry => ( + AppError::NotFoundEntry | AppError::NotFoundUser | AppError::NotFoundSecret => ( StatusCode::NOT_FOUND, Json( - json!({ "error": tr(lang, "条目不存在或无权访问", "條目不存在或無權存取", "Entry not found or no access") }), + json!({ "error": tr(lang, "资源不存在或无权访问", "資源不存在或無權存取", "Resource not found or no access") }), + ), + ), + AppError::AuthenticationFailed | AppError::Unauthorized => ( + StatusCode::UNAUTHORIZED, + Json( + json!({ "error": tr(lang, "认证失败或无权访问", "認證失敗或無權存取", "Authentication failed or unauthorized") }), ), ), AppError::Validation { message } => { @@ -1155,6 +1161,12 @@ fn map_app_error(err: &AppError, lang: UiLang) -> EntryApiError { json!({ "error": tr(lang, "解密失败,请检查密码短语", "解密失敗,請檢查密碼短語", "Decryption failed — please check your passphrase") }), ), ), + AppError::EncryptionKeyNotSet => ( + StatusCode::BAD_REQUEST, + Json( + json!({ "error": tr(lang, "请先设置密码短语后再使用此功能", "請先設定密碼短語再使用此功能", "Please set a passphrase before using this feature") }), + ), + ), AppError::Internal(_) => { tracing::error!(error = %err, "internal error in entry mutation"); ( diff --git a/crates/secrets-mcp/templates/audit.html b/crates/secrets-mcp/templates/audit.html index 5e7724d..549608d 100644 --- a/crates/secrets-mcp/templates/audit.html +++ b/crates/secrets-mcp/templates/audit.html @@ -50,8 +50,7 @@ .main { padding: 32px 24px 40px; flex: 1; } .card { background: var(--surface); border: 1px solid var(--border); border-radius: 12px; padding: 24px; width: 100%; max-width: 1180px; margin: 0 auto; } - .card-title { font-size: 20px; font-weight: 600; margin-bottom: 8px; } - .card-subtitle { color: var(--text-muted); font-size: 13px; margin-bottom: 20px; } + .card-title { font-size: 20px; font-weight: 600; margin-bottom: 20px; } .empty { color: var(--text-muted); font-size: 14px; padding: 20px 0; } table { width: 100%; border-collapse: collapse; } th, td { text-align: left; vertical-align: top; padding: 12px 10px; border-top: 1px solid var(--border); } @@ -115,7 +114,6 @@
我的审计
-
展示最近 100 条与当前用户相关的新审计记录。时间为浏览器本地时区。
{% if entries.is_empty() %}
暂无审计记录。
@@ -149,9 +147,9 @@