feat: add secrets-mcp-local gateway (proxy, unlock cache, plaintext tool gate)

This commit is contained in:
voson
2026-04-12 12:46:15 +08:00
parent 0bf06bbc73
commit 34093b0e23
8 changed files with 555 additions and 3 deletions

View File

@@ -0,0 +1,25 @@
[package]
name = "secrets-mcp-local"
version = "0.1.0"
edition.workspace = true
description = "Local MCP gateway: caches unlock credentials and proxies to remote secrets-mcp /mcp"
license = "MIT OR Apache-2.0"
[[bin]]
name = "secrets-mcp-local"
path = "src/main.rs"
[dependencies]
anyhow.workspace = true
axum = "0.8"
futures-util = "0.3"
http = "1"
reqwest = { workspace = true, features = ["stream"] }
serde.workspace = true
serde_json.workspace = true
tokio.workspace = true
tower-http = { version = "0.6", features = ["cors", "limit"] }
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["env-filter"] }
dotenvy.workspace = true
url = "2"

View File

@@ -0,0 +1,450 @@
//! Local MCP gateway: single agent-facing MCP endpoint on localhost.
//!
//! Proxies JSON-RPC to `SECRETS_REMOTE_MCP_URL` and injects `Authorization` +
//! `X-Encryption-Key` from an in-memory unlock cache (TTL). Cursor can connect
//! without embedding the encryption key in its MCP config after a one-time
//! local unlock.
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use axum::Router;
use axum::body::Body;
use axum::extract::State;
use axum::http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
use axum::response::{Html, IntoResponse, Response};
use axum::routing::{get, post};
use futures_util::TryStreamExt;
use serde::Deserialize;
use serde_json::json;
use tokio::sync::RwLock;
use tower_http::cors::CorsLayer;
use tracing_subscriber::EnvFilter;
use url::Url;
const DEFAULT_BIND: &str = "127.0.0.1:9316";
const DEFAULT_TTL_SECS: u64 = 3600;
/// Tools that return decrypted secret material; blocked when
/// `SECRETS_LOCAL_ALLOW_PLAINTEXT_TOOLS` is not `1`/`true`/`yes`.
const PLAINTEXT_TOOL_NAMES: &[&str] = &["secrets_get", "secrets_export", "secrets_env_map"];
#[derive(Clone)]
struct AppState {
remote_mcp_url: Url,
dashboard_hint_url: String,
http_client: reqwest::Client,
unlock: Arc<RwLock<Option<UnlockState>>>,
default_api_key: Option<String>,
ttl: Duration,
allow_plaintext_tools: bool,
}
struct UnlockState {
api_key: String,
encryption_key_hex: String,
expires_at: Instant,
}
#[derive(Debug, Deserialize)]
struct UnlockBody {
/// 64-char hex encryption key (PBKDF2-derived), same as remote `X-Encryption-Key`.
encryption_key: String,
/// Optional if `SECRETS_LOCAL_API_KEY` is set in the environment.
api_key: Option<String>,
/// Override TTL for this unlock (seconds).
#[serde(default)]
ttl_secs: Option<u64>,
}
fn load_env(name: &str) -> Option<String> {
std::env::var(name).ok().filter(|s| !s.is_empty())
}
fn parse_bool_env(name: &str, default: bool) -> bool {
match load_env(name).map(|s| s.to_ascii_lowercase()).as_deref() {
None => default,
Some("1" | "true" | "yes" | "on") => true,
Some("0" | "false" | "no" | "off") => false,
_ => default,
}
}
fn dashboard_url_from_remote(remote: &Url) -> String {
load_env("SECRETS_REMOTE_DASHBOARD_URL").unwrap_or_else(|| {
let mut u = remote.clone();
u.set_path("/dashboard");
u.set_query(None);
u.set_fragment(None);
u.to_string()
})
}
/// If JSON-RPC targets a blocked tool, return an error response body instead of forwarding.
fn maybe_block_plaintext_request(
allow_plaintext: bool,
method: &Method,
body: &[u8],
) -> Option<Vec<u8>> {
if allow_plaintext || *method != Method::POST || body.is_empty() {
return None;
}
let value: serde_json::Value = serde_json::from_slice(body).ok()?;
fn tool_blocked(name: &str) -> bool {
PLAINTEXT_TOOL_NAMES.contains(&name)
}
fn block_single(id: serde_json::Value, name: &str) -> serde_json::Value {
json!({
"jsonrpc": "2.0",
"id": id,
"error": {
"code": -32000,
"message": format!(
"Local gateway: tool `{name}` is disabled (set SECRETS_LOCAL_ALLOW_PLAINTEXT_TOOLS=1 to allow)."
)
}
})
}
match value {
serde_json::Value::Object(obj) => {
if obj.get("method").and_then(|m| m.as_str()) != Some("tools/call") {
return None;
}
let name = obj
.get("params")
.and_then(|p| p.get("name"))
.and_then(|n| n.as_str())?;
if !tool_blocked(name) {
return None;
}
let id = obj.get("id").cloned().unwrap_or(json!(null));
Some(block_single(id, name).to_string().into_bytes())
}
serde_json::Value::Array(arr) => {
let mut out = Vec::with_capacity(arr.len());
let mut changed = false;
for item in arr {
if let serde_json::Value::Object(ref obj) = item
&& obj.get("method").and_then(|m| m.as_str()) == Some("tools/call")
&& let Some(name) = obj
.get("params")
.and_then(|p| p.get("name"))
.and_then(|n| n.as_str())
&& tool_blocked(name)
{
changed = true;
let id = obj.get("id").cloned().unwrap_or(json!(null));
out.push(block_single(id, name));
continue;
}
out.push(item);
}
if changed {
serde_json::to_vec(&out).ok()
} else {
None
}
}
_ => None,
}
}
async fn index_html(State(state): State<Arc<AppState>>) -> impl IntoResponse {
let remote = state.remote_mcp_url.as_str();
let dash = &state.dashboard_hint_url;
Html(format!(
r#"<!DOCTYPE html>
<html lang="zh-CN">
<head><meta charset="utf-8"><title>secrets-mcp-local</title></head>
<body>
<h1>本地 MCP Gateway</h1>
<p>远程 MCP: <code>{remote}</code></p>
<p>在浏览器打开 Dashboard 登录并复制 API Key<a href="{dash}">{dash}</a></p>
<p>然后在本机执行解锁(示例):</p>
<pre>curl -sS -X POST http://127.0.0.1:9316/local/unlock \
-H "Content-Type: application/json" \
-d '{{"encryption_key":"YOUR_64_HEX","api_key":"YOUR_API_KEY"}}'</pre>
<p>或将 Cursor MCP 指向 <code>http://127.0.0.1:9316/mcp</code>(无需在配置里写 <code>X-Encryption-Key</code>)。</p>
<p><a href="/local/status">/local/status</a></p>
</body>
</html>"#,
remote = remote,
dash = dash
))
}
async fn local_status(State(state): State<Arc<AppState>>) -> impl IntoResponse {
let guard = state.unlock.read().await;
let now = Instant::now();
let body = match guard.as_ref() {
None => json!({ "unlocked": false }),
Some(u) if u.expires_at <= now => json!({ "unlocked": false, "reason": "expired" }),
Some(u) => json!({
"unlocked": true,
"expires_in_secs": u.expires_at.duration_since(now).as_secs(),
"allow_plaintext_tools": state.allow_plaintext_tools,
}),
};
(StatusCode::OK, axum::Json(body))
}
async fn local_unlock(
State(state): State<Arc<AppState>>,
axum::Json(body): axum::Json<UnlockBody>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
let hex = body.encryption_key.trim();
if hex.len() != 64 || !hex.chars().all(|c| c.is_ascii_hexdigit()) {
return Err((
StatusCode::BAD_REQUEST,
"encryption_key must be 64 hex characters".to_string(),
));
}
let api_key = body
.api_key
.or_else(|| state.default_api_key.clone())
.filter(|s| !s.is_empty())
.ok_or_else(|| {
(
StatusCode::BAD_REQUEST,
"api_key required (or set SECRETS_LOCAL_API_KEY)".to_string(),
)
})?;
let ttl_secs = body.ttl_secs.unwrap_or(state.ttl.as_secs());
let ttl = Duration::from_secs(ttl_secs.clamp(60, 86400 * 7));
let expires_at = Instant::now() + ttl;
let mut guard = state.unlock.write().await;
*guard = Some(UnlockState {
api_key,
encryption_key_hex: hex.to_string(),
expires_at,
});
tracing::info!(
ttl_secs = ttl.as_secs(),
"local unlock: credentials cached until expiry"
);
Ok((
StatusCode::OK,
axum::Json(json!({
"ok": true,
"expires_in_secs": ttl.as_secs(),
})),
))
}
async fn local_lock(State(state): State<Arc<AppState>>) -> impl IntoResponse {
let mut guard = state.unlock.write().await;
*guard = None;
tracing::info!("local lock: credentials cleared");
(StatusCode::OK, axum::Json(json!({ "ok": true })))
}
fn header_value_copy(h: &axum::http::HeaderValue) -> Option<HeaderValue> {
HeaderValue::from_bytes(h.as_bytes()).ok()
}
async fn proxy_mcp(
State(state): State<Arc<AppState>>,
method: Method,
headers: HeaderMap,
body: Body,
) -> Result<Response, Infallible> {
let now = Instant::now();
let unlock = state.unlock.read().await;
let Some(u) = unlock.as_ref() else {
return Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header(
axum::http::header::CONTENT_TYPE,
"application/json; charset=utf-8",
)
.body(Body::from(
r#"{"error":"local gateway locked: POST /local/unlock first"}"#,
))
.unwrap());
};
if u.expires_at <= now {
drop(unlock);
let mut w = state.unlock.write().await;
*w = None;
return Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.header(
axum::http::header::CONTENT_TYPE,
"application/json; charset=utf-8",
)
.body(Body::from(
r#"{"error":"local gateway unlock expired: POST /local/unlock again"}"#,
))
.unwrap());
}
let api_key = u.api_key.clone();
let enc_key = u.encryption_key_hex.clone();
drop(unlock);
let bytes = match axum::body::to_bytes(body, 10 * 1024 * 1024).await {
Ok(b) => b.to_vec(),
Err(e) => {
tracing::warn!(error = %e, "read body failed");
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("body read failed"))
.unwrap());
}
};
let body_to_send = if let Some(blocked) =
maybe_block_plaintext_request(state.allow_plaintext_tools, &method, &bytes)
{
blocked
} else {
bytes
};
let mut req_builder = state
.http_client
.request(method.clone(), state.remote_mcp_url.as_str())
.body(body_to_send);
// Forward MCP session / accept headers from client.
for name in ["accept", "content-type", "mcp-session-id", "x-mcp-session"] {
if let Ok(hn) = HeaderName::from_bytes(name.as_bytes())
&& let Some(v) = headers.get(&hn)
&& let Some(copy) = header_value_copy(v)
{
req_builder = req_builder.header(hn, copy);
}
}
req_builder = req_builder
.header(
axum::http::header::AUTHORIZATION,
format!("Bearer {}", api_key),
)
.header("X-Encryption-Key", enc_key);
let upstream = match req_builder.send().await {
Ok(r) => r,
Err(e) => {
tracing::error!(error = %e, "upstream request failed");
return Ok(Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from(format!("upstream error: {e}")))
.unwrap());
}
};
let status = upstream.status();
let mut response_builder = Response::builder().status(status.as_u16());
for (key, value) in upstream.headers().iter() {
// Skip hop-by-hop headers if any; reqwest already decompresses.
let key_str = key.as_str();
if key_str.eq_ignore_ascii_case("transfer-encoding") {
continue;
}
if let Some(v) = header_value_copy(value) {
response_builder = response_builder.header(key, v);
}
}
let stream = upstream.bytes_stream().map_err(std::io::Error::other);
let body = Body::from_stream(stream);
Ok(response_builder.body(body).unwrap())
}
#[tokio::main]
async fn main() -> Result<()> {
let _ = dotenvy::dotenv();
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::try_from_default_env()
.unwrap_or_else(|_| "secrets_mcp_local=info,tower_http=info".into()),
)
.init();
let remote_mcp_url = load_env("SECRETS_REMOTE_MCP_URL")
.context("SECRETS_REMOTE_MCP_URL is required (e.g. https://secrets.example.com/mcp)")?;
let remote_mcp_url: Url = remote_mcp_url
.parse()
.context("invalid SECRETS_REMOTE_MCP_URL")?;
let dashboard_hint_url = dashboard_url_from_remote(&remote_mcp_url);
let bind = load_env("SECRETS_MCP_LOCAL_BIND").unwrap_or_else(|| DEFAULT_BIND.to_string());
let default_api_key = load_env("SECRETS_LOCAL_API_KEY");
let ttl_secs: u64 = load_env("SECRETS_LOCAL_UNLOCK_TTL_SECS")
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_TTL_SECS);
let ttl = Duration::from_secs(ttl_secs);
let allow_plaintext_tools = parse_bool_env("SECRETS_LOCAL_ALLOW_PLAINTEXT_TOOLS", false);
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(120))
.build()
.context("failed to build HTTP client")?;
let state = Arc::new(AppState {
remote_mcp_url: remote_mcp_url.clone(),
dashboard_hint_url,
http_client,
unlock: Arc::new(RwLock::new(None)),
default_api_key,
ttl,
allow_plaintext_tools,
});
let app = Router::new()
.route("/", get(index_html))
.route("/local/unlock", post(local_unlock))
.route("/local/lock", post(local_lock))
.route("/local/status", get(local_status))
.route("/mcp", axum::routing::any(proxy_mcp))
.layer(
CorsLayer::new()
.allow_origin(tower_http::cors::Any)
.allow_methods(tower_http::cors::Any)
.allow_headers(tower_http::cors::Any),
)
.layer(tower_http::limit::RequestBodyLimitLayer::new(
10 * 1024 * 1024,
))
.with_state(state);
let addr: SocketAddr = bind
.parse()
.with_context(|| format!("invalid SECRETS_MCP_LOCAL_BIND: {bind}"))?;
tracing::info!(
bind = %addr,
remote = %remote_mcp_url,
allow_plaintext_tools = allow_plaintext_tools,
"secrets-mcp-local gateway"
);
tracing::info!("MCP (agent): http://{}/mcp", addr);
tracing::info!("Unlock: POST http://{}/local/unlock", addr);
let listener = tokio::net::TcpListener::bind(addr)
.await
.with_context(|| format!("failed to bind {addr}"))?;
axum::serve(
listener,
app.into_make_service_with_connect_info::<SocketAddr>(),
)
.await
.context("server error")?;
Ok(())
}