use std::net::SocketAddr; use axum::{ extract::{ConnectInfo, Request, State}, http::StatusCode, middleware::Next, response::Response, }; use sqlx::PgPool; use uuid::Uuid; use secrets_core::service::api_key::validate_api_key; /// Injected into request extensions after Bearer token validation. #[derive(Clone, Debug)] pub struct AuthUser { pub user_id: Uuid, } fn log_client_ip(req: &Request) -> Option { if let Some(first) = req .headers() .get("x-forwarded-for") .and_then(|v| v.to_str().ok()) .and_then(|s| s.split(',').next()) { let s = first.trim(); if !s.is_empty() { return Some(s.to_string()); } } req.extensions() .get::>() .map(|c| c.ip().to_string()) } /// Axum middleware that validates Bearer API keys for the /mcp route. /// Passes all non-MCP paths through without authentication. pub async fn bearer_auth_middleware( State(pool): State, req: Request, next: Next, ) -> Result { let path = req.uri().path(); let method = req.method().as_str(); let client_ip = log_client_ip(&req); // Only authenticate /mcp paths if !path.starts_with("/mcp") { return Ok(next.run(req).await); } // Allow OPTIONS (CORS preflight) through if req.method() == axum::http::Method::OPTIONS { return Ok(next.run(req).await); } let auth_header = req .headers() .get(axum::http::header::AUTHORIZATION) .and_then(|v| v.to_str().ok()); let raw_key = match auth_header { Some(h) if h.starts_with("Bearer ") => h.trim_start_matches("Bearer ").trim(), Some(_) => { tracing::warn!( method, path, client_ip = client_ip.as_deref(), "invalid Authorization header format on /mcp (expected Bearer …)" ); return Err(StatusCode::UNAUTHORIZED); } None => { tracing::warn!( method, path, client_ip = client_ip.as_deref(), "missing Authorization header on /mcp" ); return Err(StatusCode::UNAUTHORIZED); } }; match validate_api_key(&pool, raw_key).await { Ok(Some(user_id)) => { tracing::debug!(?user_id, "api key authenticated"); let mut req = req; req.extensions_mut().insert(AuthUser { user_id }); Ok(next.run(req).await) } Ok(None) => { tracing::warn!( method, path, client_ip = client_ip.as_deref(), key_prefix = %&raw_key.chars().take(12).collect::(), key_len = raw_key.len(), "invalid api key (not found in database — e.g. revoked key or DB was reset; update MCP client Bearer token)" ); Err(StatusCode::UNAUTHORIZED) } Err(e) => { tracing::error!( method, path, client_ip = client_ip.as_deref(), error = %e, "api key validation error" ); Err(StatusCode::INTERNAL_SERVER_ERROR) } } }