320 lines
11 KiB
Rust
320 lines
11 KiB
Rust
mod auth;
|
|
mod client_ip;
|
|
mod error;
|
|
mod logging;
|
|
mod oauth;
|
|
mod rate_limit;
|
|
mod tools;
|
|
mod validation;
|
|
mod web;
|
|
|
|
use std::net::SocketAddr;
|
|
use std::sync::Arc;
|
|
|
|
use anyhow::{Context, Result};
|
|
use axum::Router;
|
|
use rmcp::transport::streamable_http_server::{
|
|
StreamableHttpService, session::local::LocalSessionManager,
|
|
};
|
|
use sqlx::PgPool;
|
|
use tower_http::cors::{Any, CorsLayer};
|
|
use tower_sessions::cookie::SameSite;
|
|
use tower_sessions::session_store::ExpiredDeletion;
|
|
use tower_sessions::{Expiry, SessionManagerLayer};
|
|
use tower_sessions_sqlx_store_chrono::PostgresStore;
|
|
use tracing_subscriber::EnvFilter;
|
|
use tracing_subscriber::fmt::time::FormatTime;
|
|
|
|
use secrets_core::config::resolve_db_config;
|
|
use secrets_core::db::{create_pool, migrate};
|
|
|
|
use crate::oauth::OAuthConfig;
|
|
use crate::tools::SecretsService;
|
|
|
|
/// Shared application state injected into web routes and middleware.
|
|
#[derive(Clone)]
|
|
pub struct AppState {
|
|
pub pool: PgPool,
|
|
pub google_config: Option<OAuthConfig>,
|
|
pub base_url: String,
|
|
pub http_client: reqwest::Client,
|
|
}
|
|
|
|
/// Read the environment variable `name`.
///
/// Returns `None` when the variable is unset, is not valid Unicode, or is
/// set to the empty string; otherwise returns its value.
fn load_env_var(name: &str) -> Option<String> {
    match std::env::var(name) {
        Ok(value) if !value.is_empty() => Some(value),
        _ => None,
    }
}
|
|
|
|
/// Render a bind address for log output, showing `127.0.0.1:<port>` as
/// `localhost:<port>`. Any other address is returned unchanged.
///
/// This is cosmetic only; the address the socket binds to is not affected.
fn listen_addr_log_display(bind_addr: &str) -> String {
    match bind_addr.strip_prefix("127.0.0.1:") {
        Some(port) => format!("localhost:{port}"),
        None => bind_addr.to_owned(),
    }
}
|
|
|
|
fn load_oauth_config(prefix: &str, base_url: &str, path: &str) -> Option<OAuthConfig> {
|
|
let client_id = load_env_var(&format!("{}_CLIENT_ID", prefix))?;
|
|
let client_secret = load_env_var(&format!("{}_CLIENT_SECRET", prefix))?;
|
|
Some(OAuthConfig {
|
|
client_id,
|
|
client_secret,
|
|
redirect_uri: format!("{}{}", base_url, path),
|
|
})
|
|
}
|
|
|
|
/// Log line timestamps in the process local timezone (honors `TZ` / system zone).
|
|
#[derive(Clone, Copy, Default)]
|
|
struct LocalRfc3339Time;
|
|
|
|
impl FormatTime for LocalRfc3339Time {
|
|
fn format_time(&self, w: &mut tracing_subscriber::fmt::format::Writer<'_>) -> std::fmt::Result {
|
|
write!(
|
|
w,
|
|
"{}",
|
|
chrono::Local::now().to_rfc3339_opts(chrono::SecondsFormat::Millis, false)
|
|
)
|
|
}
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<()> {
|
|
// Load .env if present
|
|
let _ = dotenvy::dotenv();
|
|
|
|
tracing_subscriber::fmt()
|
|
.with_timer(LocalRfc3339Time)
|
|
.with_env_filter(
|
|
EnvFilter::try_from_default_env()
|
|
.unwrap_or_else(|_| "secrets_mcp=info,tower_http=info".into()),
|
|
)
|
|
.init();
|
|
|
|
// ── Database ──────────────────────────────────────────────────────────────
|
|
let db_config = resolve_db_config("")
|
|
.context("Database not configured. Set SECRETS_DATABASE_URL environment variable.")?;
|
|
let pool = create_pool(&db_config)
|
|
.await
|
|
.context("failed to connect to database")?;
|
|
migrate(&pool)
|
|
.await
|
|
.context("failed to run database migrations")?;
|
|
tracing::info!("Database connected and migrated");
|
|
|
|
// ── Configuration ─────────────────────────────────────────────────────────
|
|
let base_url = load_env_var("BASE_URL").unwrap_or_else(|| "http://localhost:9315".to_string());
|
|
let bind_addr =
|
|
load_env_var("SECRETS_MCP_BIND").unwrap_or_else(|| "127.0.0.1:9315".to_string());
|
|
|
|
// ── OAuth providers ───────────────────────────────────────────────────────
|
|
let google_config = load_oauth_config("GOOGLE", &base_url, "/auth/google/callback");
|
|
|
|
if google_config.is_none() {
|
|
tracing::warn!(
|
|
"No OAuth providers configured. Set GOOGLE_CLIENT_ID/GOOGLE_CLIENT_SECRET to enable login."
|
|
);
|
|
}
|
|
|
|
// ── Session store (PostgreSQL-backed) ─────────────────────────────────────
|
|
let session_store = PostgresStore::new(pool.clone());
|
|
session_store
|
|
.migrate()
|
|
.await
|
|
.context("failed to run session table migration")?;
|
|
// Prune expired rows every hour; task is aborted when the server shuts down.
|
|
let session_cleanup = tokio::spawn(
|
|
session_store
|
|
.clone()
|
|
.continuously_delete_expired(tokio::time::Duration::from_secs(3600)),
|
|
);
|
|
// Strict would drop the session cookie on redirect from Google → our origin (cross-site nav).
|
|
let session_layer = SessionManagerLayer::new(session_store)
|
|
.with_secure(base_url.starts_with("https://"))
|
|
.with_same_site(SameSite::Lax)
|
|
.with_expiry(Expiry::OnInactivity(time::Duration::days(14)));
|
|
|
|
// ── App state ─────────────────────────────────────────────────────────────
|
|
let app_state = AppState {
|
|
pool: pool.clone(),
|
|
google_config,
|
|
base_url: base_url.clone(),
|
|
http_client: reqwest::Client::builder()
|
|
.timeout(std::time::Duration::from_secs(15))
|
|
.build()
|
|
.context("failed to build HTTP client")?,
|
|
};
|
|
|
|
// ── MCP service ───────────────────────────────────────────────────────────
|
|
let pool_arc = Arc::new(pool.clone());
|
|
|
|
let mcp_service = StreamableHttpService::new(
|
|
move || {
|
|
let p = pool_arc.clone();
|
|
Ok(SecretsService::new(p))
|
|
},
|
|
LocalSessionManager::default().into(),
|
|
Default::default(),
|
|
);
|
|
|
|
// ── Router ────────────────────────────────────────────────────────────────
|
|
// CORS: restrict origins in production, allow all in development
|
|
let is_production = matches!(
|
|
load_env_var("SECRETS_ENV")
|
|
.as_deref()
|
|
.map(|s| s.to_ascii_lowercase())
|
|
.as_deref(),
|
|
Some("prod" | "production")
|
|
);
|
|
|
|
let cors = build_cors_layer(&base_url, is_production);
|
|
|
|
// Rate limiting
|
|
let rate_limit_state = rate_limit::RateLimitState::new();
|
|
let rate_limit_cleanup = rate_limit::spawn_cleanup_task(rate_limit_state.ip_limiter.clone());
|
|
|
|
let router = Router::new()
|
|
.merge(web::web_router())
|
|
.nest_service("/mcp", mcp_service)
|
|
.layer(axum::middleware::from_fn(
|
|
logging::request_logging_middleware,
|
|
))
|
|
.layer(axum::middleware::from_fn_with_state(
|
|
pool,
|
|
auth::bearer_auth_middleware,
|
|
))
|
|
.layer(axum::middleware::from_fn_with_state(
|
|
rate_limit_state.clone(),
|
|
rate_limit::rate_limit_middleware,
|
|
))
|
|
.layer(session_layer)
|
|
.layer(cors)
|
|
.with_state(app_state);
|
|
|
|
// ── Start server ──────────────────────────────────────────────────────────
|
|
let listener = tokio::net::TcpListener::bind(&bind_addr)
|
|
.await
|
|
.with_context(|| format!("failed to bind to {}", bind_addr))?;
|
|
|
|
tracing::info!(
|
|
"Secrets MCP Server listening on http://{}",
|
|
listen_addr_log_display(&bind_addr)
|
|
);
|
|
tracing::info!("MCP endpoint: {}/mcp", base_url);
|
|
|
|
axum::serve(
|
|
listener,
|
|
router.into_make_service_with_connect_info::<SocketAddr>(),
|
|
)
|
|
.with_graceful_shutdown(shutdown_signal())
|
|
.await
|
|
.context("server error")?;
|
|
|
|
session_cleanup.abort();
|
|
rate_limit_cleanup.abort();
|
|
Ok(())
|
|
}
|
|
|
|
async fn shutdown_signal() {
|
|
let ctrl_c = tokio::signal::ctrl_c();
|
|
|
|
#[cfg(unix)]
|
|
let terminate = async {
|
|
tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())
|
|
.expect("failed to install SIGTERM handler")
|
|
.recv()
|
|
.await;
|
|
};
|
|
|
|
#[cfg(not(unix))]
|
|
let terminate = std::future::pending::<()>();
|
|
|
|
tokio::select! {
|
|
_ = ctrl_c => {},
|
|
_ = terminate => {},
|
|
}
|
|
|
|
tracing::info!("Shutting down gracefully...");
|
|
}
|
|
|
|
/// Production CORS allowed headers.
|
|
///
|
|
/// When adding a new custom header to the MCP or Web API, this list must be
|
|
/// updated accordingly — otherwise browsers will block the request during
|
|
/// the CORS preflight check.
|
|
fn production_allowed_headers() -> [axum::http::HeaderName; 5] {
|
|
[
|
|
axum::http::header::AUTHORIZATION,
|
|
axum::http::header::CONTENT_TYPE,
|
|
axum::http::HeaderName::from_static("x-encryption-key"),
|
|
axum::http::HeaderName::from_static("mcp-session-id"),
|
|
axum::http::HeaderName::from_static("x-mcp-session"),
|
|
]
|
|
}
|
|
|
|
/// Build the CORS layer for the application.
|
|
///
|
|
/// In production mode the origin is restricted to the BASE_URL origin
|
|
/// (scheme://host:port, path stripped) and credentials are allowed.
|
|
/// `allow_headers` uses an explicit whitelist to avoid the tower-http
|
|
/// restriction on `allow_credentials(true)` + `allow_headers(Any)`.
|
|
///
|
|
/// In development mode all origins, methods and headers are allowed.
|
|
fn build_cors_layer(base_url: &str, is_production: bool) -> CorsLayer {
|
|
if is_production {
|
|
let allowed_origin = if let Ok(parsed) = base_url.parse::<url::Url>() {
|
|
let origin = parsed.origin().ascii_serialization();
|
|
origin
|
|
.parse::<axum::http::HeaderValue>()
|
|
.unwrap_or_else(|_| panic!("invalid BASE_URL origin: {}", origin))
|
|
} else {
|
|
base_url
|
|
.parse::<axum::http::HeaderValue>()
|
|
.unwrap_or_else(|_| panic!("invalid BASE_URL: {}", base_url))
|
|
};
|
|
CorsLayer::new()
|
|
.allow_origin(allowed_origin)
|
|
.allow_methods(Any)
|
|
.allow_headers(production_allowed_headers())
|
|
.allow_credentials(true)
|
|
} else {
|
|
CorsLayer::new()
|
|
.allow_origin(Any)
|
|
.allow_methods(Any)
|
|
.allow_headers(Any)
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Constructing the production layer from a BASE_URL with a path must not panic.
    #[test]
    fn production_cors_does_not_panic() {
        drop(build_cors_layer("https://secrets.example.com/app", true));
    }

    /// Every header the clients rely on is present in the production allowlist.
    #[test]
    fn production_cors_headers_include_all_required() {
        let allowed = production_allowed_headers();
        let required = [
            "authorization",
            "content-type",
            "x-encryption-key",
            "mcp-session-id",
            "x-mcp-session",
        ];
        for name in required {
            assert!(
                allowed.iter().any(|header| header.as_str() == name),
                "missing CORS header: {name}"
            );
        }
    }

    /// The origin derived from a URL drops the path component.
    #[test]
    fn production_cors_normalizes_base_url_with_path() {
        let parsed = url::Url::parse("https://secrets.example.com/secrets/app").unwrap();
        assert_eq!(
            parsed.origin().ascii_serialization(),
            "https://secrets.example.com"
        );
    }

    /// Constructing the permissive development layer must not panic.
    #[test]
    fn development_cors_allows_everything() {
        drop(build_cors_layer("http://localhost:9315", false));
    }
}
|