mod auth;
mod client_ip;
mod error;
mod logging;
mod oauth;
mod rate_limit;
mod tools;
mod validation;
mod web;

use std::net::SocketAddr;
use std::sync::Arc;

use anyhow::{Context, Result};
use axum::Router;
use rmcp::transport::streamable_http_server::{
    StreamableHttpService, session::local::LocalSessionManager,
};
use sqlx::PgPool;
use tower_http::cors::{Any, CorsLayer};
use tower_sessions::cookie::SameSite;
use tower_sessions::session_store::ExpiredDeletion;
use tower_sessions::{Expiry, SessionManagerLayer};
use tower_sessions_sqlx_store_chrono::PostgresStore;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::fmt::time::FormatTime;

use secrets_core::config::resolve_db_config;
use secrets_core::db::{create_pool, migrate};

use crate::oauth::OAuthConfig;
use crate::tools::SecretsService;

/// Shared application state injected into web routes and middleware.
#[derive(Clone)]
pub struct AppState {
    /// PostgreSQL connection pool.
    pub pool: PgPool,
    /// Google OAuth client credentials; `None` when
    /// `GOOGLE_CLIENT_ID` / `GOOGLE_CLIENT_SECRET` are not both set.
    pub google_config: Option<OAuthConfig>,
    /// Public base URL of the server (the `BASE_URL` environment variable).
    pub base_url: String,
    /// Shared outbound HTTP client.
    pub http_client: reqwest::Client,
}

/// Read environment variable `name`, treating unset, empty, or non-UTF-8 values as absent.
fn load_env_var(name: &str) -> Option<String> {
    std::env::var(name).ok().filter(|s| !s.is_empty())
}

/// Pretty-print bind address in logs (`127.0.0.1` → `localhost`); actual socket bind unchanged.
fn listen_addr_log_display(bind_addr: &str) -> String {
    bind_addr
        .strip_prefix("127.0.0.1:")
        .map(|port| format!("localhost:{port}"))
        .unwrap_or_else(|| bind_addr.to_string())
}

/// Build an [`OAuthConfig`] from the `{prefix}_CLIENT_ID` and `{prefix}_CLIENT_SECRET`
/// environment variables.
///
/// Returns `None` unless both variables are set and non-empty. The redirect URI is
/// `base_url` followed by `path`. Any trailing `/` on `base_url` is dropped first, so
/// `BASE_URL=https://host/` yields `https://host/auth/...` rather than
/// `https://host//auth/...`. Providers such as Google compare redirect URIs exactly,
/// so the doubled slash would break login.
fn load_oauth_config(prefix: &str, base_url: &str, path: &str) -> Option<OAuthConfig> {
    let client_id = load_env_var(&format!("{prefix}_CLIENT_ID"))?;
    let client_secret = load_env_var(&format!("{prefix}_CLIENT_SECRET"))?;
    let base = base_url.trim_end_matches('/');
    Some(OAuthConfig {
        client_id,
        client_secret,
        redirect_uri: format!("{base}{path}"),
    })
}

/// Log line timestamps in the process local timezone (honors `TZ` / system zone).
#[derive(Clone, Copy, Default)] struct LocalRfc3339Time; impl FormatTime for LocalRfc3339Time { fn format_time(&self, w: &mut tracing_subscriber::fmt::format::Writer<'_>) -> std::fmt::Result { write!( w, "{}", chrono::Local::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, false) ) } } #[tokio::main] async fn main() -> Result<()> { // Load .env if present let _ = dotenvy::dotenv(); tracing_subscriber::fmt() .with_timer(LocalRfc3339Time) .with_env_filter( EnvFilter::try_from_default_env() .unwrap_or_else(|_| "secrets_mcp=info,tower_http=info".into()), ) .init(); // ── Database ────────────────────────────────────────────────────────────── let db_config = resolve_db_config("") .context("Database not configured. Set SECRETS_DATABASE_URL environment variable.")?; let pool = create_pool(&db_config) .await .context("failed to connect to database")?; migrate(&pool) .await .context("failed to run database migrations")?; tracing::info!("Database connected and migrated"); // ── Configuration ───────────────────────────────────────────────────────── let base_url = load_env_var("BASE_URL").unwrap_or_else(|| "http://localhost:9315".to_string()); let bind_addr = load_env_var("SECRETS_MCP_BIND").unwrap_or_else(|| "127.0.0.1:9315".to_string()); // ── OAuth providers ─────────────────────────────────────────────────────── let google_config = load_oauth_config("GOOGLE", &base_url, "/auth/google/callback"); if google_config.is_none() { tracing::warn!( "No OAuth providers configured. Set GOOGLE_CLIENT_ID/GOOGLE_CLIENT_SECRET to enable login." ); } // ── Session store (PostgreSQL-backed) ───────────────────────────────────── let session_store = PostgresStore::new(pool.clone()); session_store .migrate() .await .context("failed to run session table migration")?; // Prune expired rows every hour; task is aborted when the server shuts down. 
let session_cleanup = tokio::spawn( session_store .clone() .continuously_delete_expired(tokio::time::Duration::from_secs(3600)), ); // Strict would drop the session cookie on redirect from Google → our origin (cross-site nav). let session_layer = SessionManagerLayer::new(session_store) .with_secure(base_url.starts_with("https://")) .with_same_site(SameSite::Lax) .with_expiry(Expiry::OnInactivity(time::Duration::days(14))); // ── App state ───────────────────────────────────────────────────────────── let app_state = AppState { pool: pool.clone(), google_config, base_url: base_url.clone(), http_client: reqwest::Client::builder() .timeout(std::time::Duration::from_secs(15)) .build() .context("failed to build HTTP client")?, }; // ── MCP service ─────────────────────────────────────────────────────────── let pool_arc = Arc::new(pool.clone()); let mcp_service = StreamableHttpService::new( move || { let p = pool_arc.clone(); Ok(SecretsService::new(p)) }, LocalSessionManager::default().into(), Default::default(), ); // ── Router ──────────────────────────────────────────────────────────────── // CORS: restrict origins in production, allow all in development let is_production = matches!( load_env_var("SECRETS_ENV") .as_deref() .map(|s| s.to_ascii_lowercase()) .as_deref(), Some("prod" | "production") ); let cors = build_cors_layer(&base_url, is_production); // Rate limiting let rate_limit_state = rate_limit::RateLimitState::new(); let rate_limit_cleanup = rate_limit::spawn_cleanup_task(rate_limit_state.ip_limiter.clone()); let router = Router::new() .merge(web::web_router()) .nest_service("/mcp", mcp_service) .layer(axum::middleware::from_fn( logging::request_logging_middleware, )) .layer(axum::middleware::from_fn_with_state( pool, auth::bearer_auth_middleware, )) .layer(axum::middleware::from_fn_with_state( rate_limit_state.clone(), rate_limit::rate_limit_middleware, )) .layer(session_layer) .layer(cors) .with_state(app_state); // ── Start server 
────────────────────────────────────────────────────────── let listener = tokio::net::TcpListener::bind(&bind_addr) .await .with_context(|| format!("failed to bind to {}", bind_addr))?; tracing::info!( "Secrets MCP Server listening on http://{}", listen_addr_log_display(&bind_addr) ); tracing::info!("MCP endpoint: {}/mcp", base_url); axum::serve( listener, router.into_make_service_with_connect_info::(), ) .with_graceful_shutdown(shutdown_signal()) .await .context("server error")?; session_cleanup.abort(); rate_limit_cleanup.abort(); Ok(()) } async fn shutdown_signal() { let ctrl_c = tokio::signal::ctrl_c(); #[cfg(unix)] let terminate = async { tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) .expect("failed to install SIGTERM handler") .recv() .await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => {}, _ = terminate => {}, } tracing::info!("Shutting down gracefully..."); } /// Production CORS allowed headers. /// /// When adding a new custom header to the MCP or Web API, this list must be /// updated accordingly — otherwise browsers will block the request during /// the CORS preflight check. fn production_allowed_headers() -> [axum::http::HeaderName; 5] { [ axum::http::header::AUTHORIZATION, axum::http::header::CONTENT_TYPE, axum::http::HeaderName::from_static("x-encryption-key"), axum::http::HeaderName::from_static("mcp-session-id"), axum::http::HeaderName::from_static("x-mcp-session"), ] } /// Build the CORS layer for the application. /// /// In production mode the origin is restricted to the BASE_URL origin /// (scheme://host:port, path stripped) and credentials are allowed. /// `allow_headers` uses an explicit whitelist to avoid the tower-http /// restriction on `allow_credentials(true)` + `allow_headers(Any)`. /// /// In development mode all origins, methods and headers are allowed. 
fn build_cors_layer(base_url: &str, is_production: bool) -> CorsLayer { if is_production { let allowed_origin = if let Ok(parsed) = base_url.parse::() { let origin = parsed.origin().ascii_serialization(); origin .parse::() .unwrap_or_else(|_| panic!("invalid BASE_URL origin: {}", origin)) } else { base_url .parse::() .unwrap_or_else(|_| panic!("invalid BASE_URL: {}", base_url)) }; CorsLayer::new() .allow_origin(allowed_origin) .allow_methods(Any) .allow_headers(production_allowed_headers()) .allow_credentials(true) } else { CorsLayer::new() .allow_origin(Any) .allow_methods(Any) .allow_headers(Any) } } #[cfg(test)] mod tests { use super::*; #[test] fn production_cors_does_not_panic() { let layer = build_cors_layer("https://secrets.example.com/app", true); let _ = layer; } #[test] fn production_cors_headers_include_all_required() { let headers = production_allowed_headers(); let names: Vec<&str> = headers.iter().map(|h| h.as_str()).collect(); assert!(names.contains(&"authorization")); assert!(names.contains(&"content-type")); assert!(names.contains(&"x-encryption-key")); assert!(names.contains(&"mcp-session-id")); assert!(names.contains(&"x-mcp-session")); } #[test] fn production_cors_normalizes_base_url_with_path() { let url = url::Url::parse("https://secrets.example.com/secrets/app").unwrap(); let origin = url.origin().ascii_serialization(); assert_eq!(origin, "https://secrets.example.com"); } #[test] fn development_cors_allows_everything() { let layer = build_cors_layer("http://localhost:9315", false); let _ = layer; } }