use std::time::Instant; use axum::{ body::{Body, Bytes, to_bytes}, extract::Request, http::{ HeaderMap, Method, StatusCode, header::{AUTHORIZATION, CONTENT_LENGTH, CONTENT_TYPE, USER_AGENT}, }, middleware::Next, response::{IntoResponse, Response}, }; use crate::auth::AuthUser; /// Axum middleware that logs structured info for every HTTP request. /// /// All requests: method, path, status, latency_ms, client_ip, user_agent. /// POST /mcp requests: additionally parses JSON-RPC body for jsonrpc_method, /// tool_name, jsonrpc_id, mcp_session, batch_size, tool_args (non-sensitive /// arguments only), plus masked auth_key / enc_key fingerprints and user_id /// for diagnosing header forwarding issues. /// /// Sensitive headers (Authorization, X-Encryption-Key) are never logged in /// full — only short fingerprints are emitted. pub async fn request_logging_middleware(req: Request, next: Next) -> Response { let method = req.method().clone(); let path = req.uri().path().to_string(); let ip = client_ip(&req); let ua = header_str(req.headers(), USER_AGENT); let content_len = header_str(req.headers(), CONTENT_LENGTH).and_then(|v| v.parse::().ok()); let mcp_session = req .headers() .get("mcp-session-id") .or_else(|| req.headers().get("x-mcp-session")) .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); // Capture header fingerprints before consuming the request. let auth_key = mask_bearer(req.headers()); let enc_key = mask_enc_key(req.headers()); let is_mcp_post = path.starts_with("/mcp") && method == Method::POST; let is_json = header_str(req.headers(), CONTENT_TYPE) .map(|ct| ct.contains("application/json")) .unwrap_or(false); let start = Instant::now(); // For MCP JSON-RPC POST requests, buffer body to extract JSON-RPC metadata. // We cap at 512 KiB to avoid buffering large payloads. if is_mcp_post && is_json { let cap = content_len.unwrap_or(0); if cap <= 512 * 1024 { let (parts, body) = req.into_parts(); // user_id is available after auth middleware has run (injected into extensions). let user_id = parts .extensions .get::() .map(|a| a.user_id.to_string()); match to_bytes(body, 512 * 1024).await { Ok(bytes) => { let rpc = parse_jsonrpc_meta(&bytes); let req = Request::from_parts(parts, Body::from(bytes)); let resp = next.run(req).await; let status = resp.status().as_u16(); let elapsed = start.elapsed().as_millis(); log_mcp_request( &method, &path, status, elapsed, ip.as_deref(), ua.as_deref(), content_len, mcp_session.as_deref(), auth_key.as_deref(), &enc_key, user_id.as_deref(), &rpc, ); return resp; } Err(e) => { tracing::warn!(path, error = %e, "failed to buffer MCP request body for logging"); let elapsed = start.elapsed().as_millis(); tracing::info!( method = method.as_str(), path, status = StatusCode::INTERNAL_SERVER_ERROR.as_u16(), elapsed_ms = elapsed, client_ip = ip.as_deref(), ua = ua.as_deref(), content_length = content_len, mcp_session = mcp_session.as_deref(), auth_key = auth_key.as_deref(), enc_key = enc_key.as_str(), user_id = user_id.as_deref(), "mcp request", ); return ( StatusCode::INTERNAL_SERVER_ERROR, "failed to read request body", ) .into_response(); } } } } let resp = next.run(req).await; let status = resp.status().as_u16(); let elapsed = start.elapsed().as_millis(); // Known client probe patterns that legitimately 404 — downgrade to debug to // avoid noise in production logs. These are: // • GET /.well-known/* — OAuth/OIDC discovery by MCP clients (RFC 8414 / RFC 9728) // • GET /mcp → 404 — old SSE-transport compatibility probe by clients let is_expected_probe_404 = status == 404 && (path.starts_with("/.well-known/") || (method == Method::GET && path.starts_with("/mcp"))); if is_expected_probe_404 { tracing::debug!( method = method.as_str(), path, status, elapsed_ms = elapsed, client_ip = ip.as_deref(), ua = ua.as_deref(), "probe request (not found — expected)", ); } else { log_http_request( &method, &path, status, elapsed, ip.as_deref(), ua.as_deref(), content_len, ); } resp } // ── Logging helpers ─────────────────────────────────────────────────────────── fn log_http_request( method: &Method, path: &str, status: u16, elapsed_ms: u128, client_ip: Option<&str>, ua: Option<&str>, content_length: Option, ) { tracing::info!( method = method.as_str(), path, status, elapsed_ms, client_ip, ua, content_length, "http request", ); } #[allow(clippy::too_many_arguments)] fn log_mcp_request( method: &Method, path: &str, status: u16, elapsed_ms: u128, client_ip: Option<&str>, ua: Option<&str>, content_length: Option, mcp_session: Option<&str>, auth_key: Option<&str>, enc_key: &str, user_id: Option<&str>, rpc: &JsonRpcMeta, ) { tracing::info!( method = method.as_str(), path, status, elapsed_ms, client_ip, ua, content_length, mcp_session, jsonrpc = rpc.rpc_method.as_deref(), tool = rpc.tool_name.as_deref(), jsonrpc_id = rpc.request_id.as_deref(), batch_size = rpc.batch_size, tool_args = rpc.tool_args.as_deref(), auth_key, enc_key, user_id, "mcp request", ); } // ── Sensitive header masking ────────────────────────────────────────────────── /// Mask a Bearer token: emit only the first 12 characters followed by `…`. /// Returns `None` if the Authorization header is absent or not a Bearer token. /// Example: `sk_90c88844e4e5…` fn mask_bearer(headers: &HeaderMap) -> Option { let val = headers.get(AUTHORIZATION)?.to_str().ok()?; let token = val.strip_prefix("Bearer ")?.trim(); if token.is_empty() { return None; } if token.len() > 12 { Some(format!("{}…", &token[..12])) } else { Some(token.to_string()) } } /// Fingerprint the X-Encryption-Key header. /// /// Emits first 4 chars, last 4 chars, and raw byte length, e.g. `146b…5516(64)`. /// Returns `"absent"` when the header is missing. Reveals enough to confirm /// which key arrived and whether it was truncated or padded, without revealing /// the full value. fn mask_enc_key(headers: &HeaderMap) -> String { match headers .get("x-encryption-key") .and_then(|v| v.to_str().ok()) { Some(val) => { let raw_len = val.len(); let t = val.trim(); let len = t.len(); if len >= 8 { let prefix = &t[..4]; let suffix = &t[len - 4..]; if raw_len != len { // Trailing/leading whitespace detected — extra diagnostic. format!("{prefix}…{suffix}({len}, raw={raw_len})") } else { format!("{prefix}…{suffix}({len})") } } else { format!("…({len})") } } None => "absent".to_string(), } } // ── JSON-RPC body parsing ───────────────────────────────────────────────────── /// Safe (non-sensitive) argument keys that may be included verbatim in logs. /// Keys NOT in this list (e.g. `secrets`, `secrets_obj`, `meta_obj`, /// `encryption_key`) are silently dropped. const SAFE_ARG_KEYS: &[&str] = &[ "id", "name", "name_query", "folder", "type", "entry_type", "field", "query", "tags", "limit", "offset", "format", "dry_run", "prefix", ]; #[derive(Debug, Default)] struct JsonRpcMeta { request_id: Option, rpc_method: Option, tool_name: Option, batch_size: Option, /// Non-sensitive tool call arguments for diagnostic logging. tool_args: Option, } fn parse_jsonrpc_meta(bytes: &Bytes) -> JsonRpcMeta { let Ok(value) = serde_json::from_slice::(bytes) else { return JsonRpcMeta::default(); }; if let Some(arr) = value.as_array() { // Batch request: summarise method(s) from first element only let first = arr.first().map(parse_single).unwrap_or_default(); return JsonRpcMeta { batch_size: Some(arr.len()), ..first }; } parse_single(&value) } fn parse_single(value: &serde_json::Value) -> JsonRpcMeta { let request_id = value.get("id").and_then(json_to_string); let rpc_method = value .get("method") .and_then(|v| v.as_str()) .map(|s| s.to_string()); let tool_name = value .pointer("/params/name") .and_then(|v| v.as_str()) .map(|s| s.to_string()); let tool_args = extract_tool_args(value); JsonRpcMeta { request_id, rpc_method, tool_name, batch_size: None, tool_args, } } /// Extract a compact summary of non-sensitive tool arguments for logging. /// Only keys listed in `SAFE_ARG_KEYS` are included. fn extract_tool_args(value: &serde_json::Value) -> Option { let args = value.pointer("/params/arguments")?; let obj = args.as_object()?; let pairs: Vec = obj .iter() .filter(|(k, v)| SAFE_ARG_KEYS.contains(&k.as_str()) && !v.is_null()) .map(|(k, v)| format!("{}={}", k, summarize_value(v))) .collect(); if pairs.is_empty() { None } else { Some(pairs.join(" ")) } } /// Produce a short, log-safe representation of a JSON value. fn summarize_value(v: &serde_json::Value) -> String { match v { serde_json::Value::String(s) => { if s.len() > 64 { format!("\"{}…\"", &s[..64]) } else { format!("\"{s}\"") } } serde_json::Value::Array(arr) => format!("[…{}]", arr.len()), serde_json::Value::Object(_) => "{…}".to_string(), other => other.to_string(), } } fn json_to_string(value: &serde_json::Value) -> Option { match value { serde_json::Value::Null => None, serde_json::Value::String(s) => Some(s.clone()), serde_json::Value::Number(n) => Some(n.to_string()), serde_json::Value::Bool(b) => Some(b.to_string()), other => Some(other.to_string()), } } // ── Header helpers ──────────────────────────────────────────────────────────── fn header_str(headers: &HeaderMap, name: impl axum::http::header::AsHeaderName) -> Option { headers .get(name) .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()) } fn client_ip(req: &Request) -> Option { crate::client_ip::extract_client_ip(req).into() }