refactor: workspace secrets-core + secrets-mcp MCP SaaS
- Split library (db/crypto/service) and MCP/Web/OAuth binary - Add deploy examples and CI/docs updates Made-with: Cursor
This commit is contained in:
114
crates/secrets-mcp/src/auth.rs
Normal file
114
crates/secrets-mcp/src/auth.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use axum::{
|
||||
extract::{ConnectInfo, Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
use secrets_core::service::api_key::validate_api_key;
|
||||
|
||||
/// Injected into request extensions after Bearer token validation.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct AuthUser {
|
||||
pub user_id: Uuid,
|
||||
}
|
||||
|
||||
fn log_client_ip(req: &Request) -> Option<String> {
|
||||
if let Some(first) = req
|
||||
.headers()
|
||||
.get("x-forwarded-for")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.split(',').next())
|
||||
{
|
||||
let s = first.trim();
|
||||
if !s.is_empty() {
|
||||
return Some(s.to_string());
|
||||
}
|
||||
}
|
||||
req.extensions()
|
||||
.get::<ConnectInfo<SocketAddr>>()
|
||||
.map(|c| c.ip().to_string())
|
||||
}
|
||||
|
||||
/// Axum middleware that validates Bearer API keys for the /mcp route.
|
||||
/// Passes all non-MCP paths through without authentication.
|
||||
pub async fn bearer_auth_middleware(
|
||||
State(pool): State<PgPool>,
|
||||
req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let path = req.uri().path();
|
||||
let method = req.method().as_str();
|
||||
let client_ip = log_client_ip(&req);
|
||||
|
||||
// Only authenticate /mcp paths
|
||||
if !path.starts_with("/mcp") {
|
||||
return Ok(next.run(req).await);
|
||||
}
|
||||
|
||||
// Allow OPTIONS (CORS preflight) through
|
||||
if req.method() == axum::http::Method::OPTIONS {
|
||||
return Ok(next.run(req).await);
|
||||
}
|
||||
|
||||
let auth_header = req
|
||||
.headers()
|
||||
.get(axum::http::header::AUTHORIZATION)
|
||||
.and_then(|v| v.to_str().ok());
|
||||
|
||||
let raw_key = match auth_header {
|
||||
Some(h) if h.starts_with("Bearer ") => h.trim_start_matches("Bearer ").trim(),
|
||||
Some(_) => {
|
||||
tracing::warn!(
|
||||
method,
|
||||
path,
|
||||
client_ip = client_ip.as_deref(),
|
||||
"invalid Authorization header format on /mcp (expected Bearer …)"
|
||||
);
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
None => {
|
||||
tracing::warn!(
|
||||
method,
|
||||
path,
|
||||
client_ip = client_ip.as_deref(),
|
||||
"missing Authorization header on /mcp"
|
||||
);
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
};
|
||||
|
||||
match validate_api_key(&pool, raw_key).await {
|
||||
Ok(Some(user_id)) => {
|
||||
tracing::debug!(?user_id, "api key authenticated");
|
||||
let mut req = req;
|
||||
req.extensions_mut().insert(AuthUser { user_id });
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
Ok(None) => {
|
||||
tracing::warn!(
|
||||
method,
|
||||
path,
|
||||
client_ip = client_ip.as_deref(),
|
||||
key_prefix = %&raw_key.chars().take(12).collect::<String>(),
|
||||
key_len = raw_key.len(),
|
||||
"invalid api key (not found in database — e.g. revoked key or DB was reset; update MCP client Bearer token)"
|
||||
);
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(
|
||||
method,
|
||||
path,
|
||||
client_ip = client_ip.as_deref(),
|
||||
error = %e,
|
||||
"api key validation error"
|
||||
);
|
||||
Err(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
}
|
||||
}
|
||||
155
crates/secrets-mcp/src/main.rs
Normal file
155
crates/secrets-mcp/src/main.rs
Normal file
@@ -0,0 +1,155 @@
|
||||
mod auth;
|
||||
mod oauth;
|
||||
mod tools;
|
||||
mod web;
|
||||
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use axum::Router;
|
||||
use rmcp::transport::streamable_http_server::{
|
||||
StreamableHttpService, session::local::LocalSessionManager,
|
||||
};
|
||||
use sqlx::PgPool;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tower_sessions::cookie::SameSite;
|
||||
use tower_sessions::{MemoryStore, SessionManagerLayer};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
use secrets_core::config::resolve_db_url;
|
||||
use secrets_core::db::{create_pool, migrate};
|
||||
|
||||
use crate::oauth::OAuthConfig;
|
||||
use crate::tools::SecretsService;
|
||||
|
||||
/// Shared application state injected into web routes and middleware.
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub pool: PgPool,
|
||||
pub google_config: Option<OAuthConfig>,
|
||||
pub base_url: String,
|
||||
pub http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
fn load_env_var(name: &str) -> Option<String> {
|
||||
std::env::var(name).ok().filter(|s| !s.is_empty())
|
||||
}
|
||||
|
||||
fn load_oauth_config(prefix: &str, base_url: &str, path: &str) -> Option<OAuthConfig> {
|
||||
let client_id = load_env_var(&format!("{}_CLIENT_ID", prefix))?;
|
||||
let client_secret = load_env_var(&format!("{}_CLIENT_SECRET", prefix))?;
|
||||
Some(OAuthConfig {
|
||||
client_id,
|
||||
client_secret,
|
||||
redirect_uri: format!("{}{}", base_url, path),
|
||||
})
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Load .env if present
|
||||
let _ = dotenvy::dotenv();
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| "secrets_mcp=info".into()),
|
||||
)
|
||||
.init();
|
||||
|
||||
// ── Database ──────────────────────────────────────────────────────────────
|
||||
let db_url = resolve_db_url("")
|
||||
.context("Database not configured. Set SECRETS_DATABASE_URL environment variable.")?;
|
||||
let pool = create_pool(&db_url)
|
||||
.await
|
||||
.context("failed to connect to database")?;
|
||||
migrate(&pool)
|
||||
.await
|
||||
.context("failed to run database migrations")?;
|
||||
tracing::info!("Database connected and migrated");
|
||||
|
||||
// ── Configuration ─────────────────────────────────────────────────────────
|
||||
let base_url = load_env_var("BASE_URL").unwrap_or_else(|| "http://localhost:9315".to_string());
|
||||
let bind_addr = load_env_var("SECRETS_MCP_BIND").unwrap_or_else(|| "0.0.0.0:9315".to_string());
|
||||
|
||||
// ── OAuth providers ───────────────────────────────────────────────────────
|
||||
let google_config = load_oauth_config("GOOGLE", &base_url, "/auth/google/callback");
|
||||
|
||||
if google_config.is_none() {
|
||||
tracing::warn!(
|
||||
"No OAuth providers configured. Set GOOGLE_CLIENT_ID/GOOGLE_CLIENT_SECRET to enable login."
|
||||
);
|
||||
}
|
||||
|
||||
// ── Session store ─────────────────────────────────────────────────────────
|
||||
let session_store = MemoryStore::default();
|
||||
// Strict would drop the session cookie on redirect from Google → our origin (cross-site nav).
|
||||
let session_layer = SessionManagerLayer::new(session_store)
|
||||
.with_secure(base_url.starts_with("https://"))
|
||||
.with_same_site(SameSite::Lax);
|
||||
|
||||
// ── App state ─────────────────────────────────────────────────────────────
|
||||
let app_state = AppState {
|
||||
pool: pool.clone(),
|
||||
google_config,
|
||||
base_url: base_url.clone(),
|
||||
http_client: reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(15))
|
||||
.build()
|
||||
.context("failed to build HTTP client")?,
|
||||
};
|
||||
|
||||
// ── MCP service ───────────────────────────────────────────────────────────
|
||||
let pool_arc = Arc::new(pool.clone());
|
||||
|
||||
let mcp_service = StreamableHttpService::new(
|
||||
move || {
|
||||
let p = pool_arc.clone();
|
||||
Ok(SecretsService::new(p))
|
||||
},
|
||||
LocalSessionManager::default().into(),
|
||||
Default::default(),
|
||||
);
|
||||
|
||||
// ── Router ────────────────────────────────────────────────────────────────
|
||||
let cors = CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
let router = Router::new()
|
||||
.merge(web::web_router())
|
||||
.nest_service("/mcp", mcp_service)
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
pool,
|
||||
auth::bearer_auth_middleware,
|
||||
))
|
||||
.layer(session_layer)
|
||||
.layer(cors)
|
||||
.with_state(app_state);
|
||||
|
||||
// ── Start server ──────────────────────────────────────────────────────────
|
||||
let listener = tokio::net::TcpListener::bind(&bind_addr)
|
||||
.await
|
||||
.with_context(|| format!("failed to bind to {}", bind_addr))?;
|
||||
|
||||
tracing::info!("Secrets MCP Server listening on http://{}", bind_addr);
|
||||
tracing::info!("MCP endpoint: {}/mcp", base_url);
|
||||
|
||||
axum::serve(
|
||||
listener,
|
||||
router.into_make_service_with_connect_info::<SocketAddr>(),
|
||||
)
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await
|
||||
.context("server error")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn shutdown_signal() {
|
||||
tokio::signal::ctrl_c()
|
||||
.await
|
||||
.expect("failed to install CTRL+C signal handler");
|
||||
tracing::info!("Shutting down gracefully...");
|
||||
}
|
||||
66
crates/secrets-mcp/src/oauth/google.rs
Normal file
66
crates/secrets-mcp/src/oauth/google.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use anyhow::{Context, Result};
|
||||
use serde::Deserialize;
|
||||
|
||||
use super::{OAuthConfig, OAuthUserInfo};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
#[allow(dead_code)]
|
||||
token_type: String,
|
||||
#[allow(dead_code)]
|
||||
id_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct UserInfo {
|
||||
sub: String,
|
||||
email: Option<String>,
|
||||
name: Option<String>,
|
||||
picture: Option<String>,
|
||||
}
|
||||
|
||||
/// Exchange authorization code for tokens and fetch user profile.
|
||||
pub async fn exchange_code(
|
||||
client: &reqwest::Client,
|
||||
config: &OAuthConfig,
|
||||
code: &str,
|
||||
) -> Result<OAuthUserInfo> {
|
||||
let token_resp: TokenResponse = client
|
||||
.post("https://oauth2.googleapis.com/token")
|
||||
.form(&[
|
||||
("code", code),
|
||||
("client_id", &config.client_id),
|
||||
("client_secret", &config.client_secret),
|
||||
("redirect_uri", &config.redirect_uri),
|
||||
("grant_type", "authorization_code"),
|
||||
])
|
||||
.send()
|
||||
.await
|
||||
.context("failed to exchange Google code")?
|
||||
.error_for_status()
|
||||
.context("Google token endpoint error")?
|
||||
.json()
|
||||
.await
|
||||
.context("failed to parse Google token response")?;
|
||||
|
||||
let user: UserInfo = client
|
||||
.get("https://openidconnect.googleapis.com/v1/userinfo")
|
||||
.bearer_auth(&token_resp.access_token)
|
||||
.send()
|
||||
.await
|
||||
.context("failed to fetch Google userinfo")?
|
||||
.error_for_status()
|
||||
.context("Google userinfo endpoint error")?
|
||||
.json()
|
||||
.await
|
||||
.context("failed to parse Google userinfo")?;
|
||||
|
||||
Ok(OAuthUserInfo {
|
||||
provider: "google".to_string(),
|
||||
provider_id: user.sub,
|
||||
email: user.email,
|
||||
name: user.name,
|
||||
avatar_url: user.picture,
|
||||
})
|
||||
}
|
||||
45
crates/secrets-mcp/src/oauth/mod.rs
Normal file
45
crates/secrets-mcp/src/oauth/mod.rs
Normal file
@@ -0,0 +1,45 @@
|
||||
pub mod google;
|
||||
pub mod wechat; // not yet implemented — placeholder for future WeChat integration
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Normalized OAuth user profile from any provider.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OAuthUserInfo {
|
||||
pub provider: String,
|
||||
pub provider_id: String,
|
||||
pub email: Option<String>,
|
||||
pub name: Option<String>,
|
||||
pub avatar_url: Option<String>,
|
||||
}
|
||||
|
||||
/// OAuth provider configuration.
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct OAuthConfig {
|
||||
pub client_id: String,
|
||||
pub client_secret: String,
|
||||
pub redirect_uri: String,
|
||||
}
|
||||
|
||||
/// Build the Google authorization URL.
|
||||
pub fn google_auth_url(config: &OAuthConfig, state: &str) -> String {
|
||||
format!(
|
||||
"https://accounts.google.com/o/oauth2/v2/auth\
|
||||
?client_id={}\
|
||||
&redirect_uri={}\
|
||||
&response_type=code\
|
||||
&scope=openid%20email%20profile\
|
||||
&state={}\
|
||||
&access_type=offline",
|
||||
urlencoding::encode(&config.client_id),
|
||||
urlencoding::encode(&config.redirect_uri),
|
||||
urlencoding::encode(state),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn random_state() -> String {
|
||||
use rand::RngExt;
|
||||
let mut bytes = [0u8; 16];
|
||||
rand::rng().fill(&mut bytes);
|
||||
bytes.iter().map(|b| format!("{:02x}", b)).collect()
|
||||
}
|
||||
18
crates/secrets-mcp/src/oauth/wechat.rs
Normal file
18
crates/secrets-mcp/src/oauth/wechat.rs
Normal file
@@ -0,0 +1,18 @@
|
||||
use super::{OAuthConfig, OAuthUserInfo};
|
||||
/// WeChat OAuth — not yet implemented.
|
||||
///
|
||||
/// This module is a placeholder for future WeChat Open Platform integration.
|
||||
/// When ready, implement `exchange_code` following the non-standard WeChat OAuth 2.0 flow:
|
||||
/// - Token exchange uses a GET request (not POST)
|
||||
/// - Preferred user identifier is `unionid` (cross-app), falling back to `openid`
|
||||
/// - Docs: https://developers.weixin.qq.com/doc/oplatform/Website_App/WeChat_Login/Wechat_Login.html
|
||||
use anyhow::{Result, bail};
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub async fn exchange_code(
|
||||
_client: &reqwest::Client,
|
||||
_config: &OAuthConfig,
|
||||
_code: &str,
|
||||
) -> Result<OAuthUserInfo> {
|
||||
bail!("WeChat login is not yet implemented")
|
||||
}
|
||||
609
crates/secrets-mcp/src/tools.rs
Normal file
609
crates/secrets-mcp/src/tools.rs
Normal file
@@ -0,0 +1,609 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use rmcp::{
|
||||
RoleServer, ServerHandler,
|
||||
handler::server::wrapper::Parameters,
|
||||
model::{
|
||||
CallToolResult, Content, Implementation, InitializeResult, ProtocolVersion,
|
||||
ServerCapabilities,
|
||||
},
|
||||
service::RequestContext,
|
||||
tool, tool_handler, tool_router,
|
||||
};
|
||||
use schemars::JsonSchema;
|
||||
use serde::Deserialize;
|
||||
use sqlx::PgPool;
|
||||
use uuid::Uuid;
|
||||
|
||||
use secrets_core::service::{
|
||||
add::{AddParams, run as svc_add},
|
||||
delete::{DeleteParams, run as svc_delete},
|
||||
export::{ExportParams, export as svc_export},
|
||||
get_secret::{get_all_secrets, get_secret_field},
|
||||
history::run as svc_history,
|
||||
rollback::run as svc_rollback,
|
||||
search::{SearchParams, run as svc_search},
|
||||
update::{UpdateParams, run as svc_update},
|
||||
};
|
||||
|
||||
use crate::auth::AuthUser;
|
||||
|
||||
// ── Shared state ──────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct SecretsService {
|
||||
pub pool: Arc<PgPool>,
|
||||
pub tool_router: rmcp::handler::server::router::tool::ToolRouter<SecretsService>,
|
||||
}
|
||||
|
||||
impl SecretsService {
|
||||
pub fn new(pool: Arc<PgPool>) -> Self {
|
||||
Self {
|
||||
pool,
|
||||
tool_router: Self::tool_router(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract user_id from the HTTP request parts injected by auth middleware.
|
||||
fn user_id_from_ctx(ctx: &RequestContext<RoleServer>) -> Result<Option<Uuid>, rmcp::ErrorData> {
|
||||
let parts = ctx
|
||||
.extensions
|
||||
.get::<http::request::Parts>()
|
||||
.ok_or_else(|| rmcp::ErrorData::internal_error("Missing HTTP parts", None))?;
|
||||
Ok(parts.extensions.get::<AuthUser>().map(|a| a.user_id))
|
||||
}
|
||||
|
||||
/// Get the authenticated user_id (returns error if not authenticated).
|
||||
fn require_user_id(ctx: &RequestContext<RoleServer>) -> Result<Uuid, rmcp::ErrorData> {
|
||||
let parts = ctx
|
||||
.extensions
|
||||
.get::<http::request::Parts>()
|
||||
.ok_or_else(|| rmcp::ErrorData::internal_error("Missing HTTP parts", None))?;
|
||||
parts
|
||||
.extensions
|
||||
.get::<AuthUser>()
|
||||
.map(|a| a.user_id)
|
||||
.ok_or_else(|| rmcp::ErrorData::invalid_request("Unauthorized: API key required", None))
|
||||
}
|
||||
|
||||
/// Extract the 32-byte encryption key from the X-Encryption-Key request header.
|
||||
/// The header value must be 64 lowercase hex characters (PBKDF2-derived key).
|
||||
fn extract_enc_key(ctx: &RequestContext<RoleServer>) -> Result<[u8; 32], rmcp::ErrorData> {
|
||||
let parts = ctx
|
||||
.extensions
|
||||
.get::<http::request::Parts>()
|
||||
.ok_or_else(|| rmcp::ErrorData::internal_error("Missing HTTP parts", None))?;
|
||||
let hex_str = parts
|
||||
.headers
|
||||
.get("x-encryption-key")
|
||||
.ok_or_else(|| {
|
||||
rmcp::ErrorData::invalid_request(
|
||||
"Missing X-Encryption-Key header. \
|
||||
Set this to your 64-char hex encryption key derived from your passphrase.",
|
||||
None,
|
||||
)
|
||||
})?
|
||||
.to_str()
|
||||
.map_err(|_| {
|
||||
rmcp::ErrorData::invalid_request("Invalid X-Encryption-Key header value", None)
|
||||
})?;
|
||||
secrets_core::crypto::extract_key_from_hex(hex_str)
|
||||
.map_err(|e| rmcp::ErrorData::invalid_request(e.to_string(), None))
|
||||
}
|
||||
|
||||
/// Require both user_id and encryption key.
|
||||
fn require_user_and_key(
|
||||
ctx: &RequestContext<RoleServer>,
|
||||
) -> Result<(Uuid, [u8; 32]), rmcp::ErrorData> {
|
||||
let user_id = Self::require_user_id(ctx)?;
|
||||
let key = Self::extract_enc_key(ctx)?;
|
||||
Ok((user_id, key))
|
||||
}
|
||||
}
|
||||
|
||||
// ── Tool parameter types ──────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct SearchInput {
|
||||
#[schemars(description = "Namespace filter (e.g. 'refining', 'ricnsmart')")]
|
||||
namespace: Option<String>,
|
||||
#[schemars(description = "Kind filter (e.g. 'server', 'service', 'key')")]
|
||||
kind: Option<String>,
|
||||
#[schemars(description = "Exact record name")]
|
||||
name: Option<String>,
|
||||
#[schemars(description = "Tag filters (all must match)")]
|
||||
tags: Option<Vec<String>>,
|
||||
#[schemars(description = "Fuzzy search across name, namespace, kind, tags, metadata")]
|
||||
query: Option<String>,
|
||||
#[schemars(description = "Return only summary fields (name/tags/desc/updated_at)")]
|
||||
summary: Option<bool>,
|
||||
#[schemars(description = "Sort order: 'name' (default), 'updated', 'created'")]
|
||||
sort: Option<String>,
|
||||
#[schemars(description = "Max results (default 20)")]
|
||||
limit: Option<u32>,
|
||||
#[schemars(description = "Pagination offset (default 0)")]
|
||||
offset: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct GetSecretInput {
|
||||
#[schemars(description = "Namespace of the entry")]
|
||||
namespace: String,
|
||||
#[schemars(description = "Kind of the entry (e.g. 'server', 'service')")]
|
||||
kind: String,
|
||||
#[schemars(description = "Name of the entry")]
|
||||
name: String,
|
||||
#[schemars(description = "Specific field to retrieve. If omitted, returns all fields.")]
|
||||
field: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct AddInput {
|
||||
#[schemars(description = "Namespace")]
|
||||
namespace: String,
|
||||
#[schemars(description = "Kind (e.g. 'server', 'service', 'key')")]
|
||||
kind: String,
|
||||
#[schemars(description = "Unique name within namespace+kind")]
|
||||
name: String,
|
||||
#[schemars(description = "Tags for this entry")]
|
||||
tags: Option<Vec<String>>,
|
||||
#[schemars(description = "Metadata fields as 'key=value' or 'key:=json' strings")]
|
||||
meta: Option<Vec<String>>,
|
||||
#[schemars(description = "Secret fields as 'key=value' strings")]
|
||||
secrets: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct UpdateInput {
|
||||
#[schemars(description = "Namespace")]
|
||||
namespace: String,
|
||||
#[schemars(description = "Kind")]
|
||||
kind: String,
|
||||
#[schemars(description = "Name")]
|
||||
name: String,
|
||||
#[schemars(description = "Tags to add")]
|
||||
add_tags: Option<Vec<String>>,
|
||||
#[schemars(description = "Tags to remove")]
|
||||
remove_tags: Option<Vec<String>>,
|
||||
#[schemars(description = "Metadata fields to update/add as 'key=value' strings")]
|
||||
meta: Option<Vec<String>>,
|
||||
#[schemars(description = "Metadata field keys to remove")]
|
||||
remove_meta: Option<Vec<String>>,
|
||||
#[schemars(description = "Secret fields to update/add as 'key=value' strings")]
|
||||
secrets: Option<Vec<String>>,
|
||||
#[schemars(description = "Secret field keys to remove")]
|
||||
remove_secrets: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct DeleteInput {
|
||||
#[schemars(description = "Namespace")]
|
||||
namespace: String,
|
||||
#[schemars(description = "Kind filter (required for single delete)")]
|
||||
kind: Option<String>,
|
||||
#[schemars(description = "Exact name to delete. Omit for bulk delete by namespace+kind.")]
|
||||
name: Option<String>,
|
||||
#[schemars(description = "Preview deletions without writing")]
|
||||
dry_run: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct HistoryInput {
|
||||
#[schemars(description = "Namespace")]
|
||||
namespace: String,
|
||||
#[schemars(description = "Kind")]
|
||||
kind: String,
|
||||
#[schemars(description = "Name")]
|
||||
name: String,
|
||||
#[schemars(description = "Max history entries to return (default 20)")]
|
||||
limit: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct RollbackInput {
|
||||
#[schemars(description = "Namespace")]
|
||||
namespace: String,
|
||||
#[schemars(description = "Kind")]
|
||||
kind: String,
|
||||
#[schemars(description = "Name")]
|
||||
name: String,
|
||||
#[schemars(description = "Target version number. Omit to restore the most recent snapshot.")]
|
||||
to_version: Option<i64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct ExportInput {
|
||||
#[schemars(description = "Namespace filter")]
|
||||
namespace: Option<String>,
|
||||
#[schemars(description = "Kind filter")]
|
||||
kind: Option<String>,
|
||||
#[schemars(description = "Exact name filter")]
|
||||
name: Option<String>,
|
||||
#[schemars(description = "Tag filters")]
|
||||
tags: Option<Vec<String>>,
|
||||
#[schemars(description = "Fuzzy query")]
|
||||
query: Option<String>,
|
||||
#[schemars(description = "Export format: 'json' (default), 'toml', 'yaml'")]
|
||||
format: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, JsonSchema)]
|
||||
struct EnvMapInput {
|
||||
#[schemars(description = "Namespace filter")]
|
||||
namespace: Option<String>,
|
||||
#[schemars(description = "Kind filter")]
|
||||
kind: Option<String>,
|
||||
#[schemars(description = "Exact name filter")]
|
||||
name: Option<String>,
|
||||
#[schemars(description = "Tag filters")]
|
||||
tags: Option<Vec<String>>,
|
||||
#[schemars(description = "Only include these secret fields")]
|
||||
only_fields: Option<Vec<String>>,
|
||||
#[schemars(description = "Environment variable name prefix")]
|
||||
prefix: Option<String>,
|
||||
}
|
||||
|
||||
// ── Tool implementations ──────────────────────────────────────────────────────
|
||||
|
||||
#[tool_router]
|
||||
impl SecretsService {
|
||||
#[tool(
|
||||
description = "Search entries in the secrets store. Returns entries with metadata and \
|
||||
secret field names (not values). Use secrets_get to decrypt secret values."
|
||||
)]
|
||||
async fn secrets_search(
|
||||
&self,
|
||||
Parameters(input): Parameters<SearchInput>,
|
||||
ctx: RequestContext<RoleServer>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
let user_id = Self::user_id_from_ctx(&ctx)?;
|
||||
let tags = input.tags.unwrap_or_default();
|
||||
let result = svc_search(
|
||||
&self.pool,
|
||||
SearchParams {
|
||||
namespace: input.namespace.as_deref(),
|
||||
kind: input.kind.as_deref(),
|
||||
name: input.name.as_deref(),
|
||||
tags: &tags,
|
||||
query: input.query.as_deref(),
|
||||
sort: input.sort.as_deref().unwrap_or("name"),
|
||||
limit: input.limit.unwrap_or(20),
|
||||
offset: input.offset.unwrap_or(0),
|
||||
user_id,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
|
||||
|
||||
let summary = input.summary.unwrap_or(false);
|
||||
let entries: Vec<serde_json::Value> = result
|
||||
.entries
|
||||
.iter()
|
||||
.map(|e| {
|
||||
if summary {
|
||||
serde_json::json!({
|
||||
"namespace": e.namespace,
|
||||
"kind": e.kind,
|
||||
"name": e.name,
|
||||
"tags": e.tags,
|
||||
"desc": e.metadata.get("desc").or_else(|| e.metadata.get("url"))
|
||||
.and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"updated_at": e.updated_at.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
|
||||
})
|
||||
} else {
|
||||
let schema: Vec<&str> = result
|
||||
.secret_schemas
|
||||
.get(&e.id)
|
||||
.map(|f| f.iter().map(|s| s.field_name.as_str()).collect())
|
||||
.unwrap_or_default();
|
||||
serde_json::json!({
|
||||
"id": e.id,
|
||||
"namespace": e.namespace,
|
||||
"kind": e.kind,
|
||||
"name": e.name,
|
||||
"tags": e.tags,
|
||||
"metadata": e.metadata,
|
||||
"secret_fields": schema,
|
||||
"version": e.version,
|
||||
"updated_at": e.updated_at.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
|
||||
})
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let json = serde_json::to_string_pretty(&entries).unwrap_or_else(|_| "[]".to_string());
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
#[tool(
|
||||
description = "Get decrypted secret field values for an entry. Requires your \
|
||||
encryption key via X-Encryption-Key header (64 hex chars, PBKDF2-derived). \
|
||||
Returns all fields, or a specific field if 'field' is provided."
|
||||
)]
|
||||
async fn secrets_get(
|
||||
&self,
|
||||
Parameters(input): Parameters<GetSecretInput>,
|
||||
ctx: RequestContext<RoleServer>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
let (user_id, user_key) = Self::require_user_and_key(&ctx)?;
|
||||
|
||||
if let Some(field_name) = &input.field {
|
||||
let value = get_secret_field(
|
||||
&self.pool,
|
||||
&input.namespace,
|
||||
&input.kind,
|
||||
&input.name,
|
||||
field_name,
|
||||
&user_key,
|
||||
Some(user_id),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
|
||||
|
||||
let result = serde_json::json!({ field_name: value });
|
||||
let json = serde_json::to_string_pretty(&result).unwrap_or_default();
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
} else {
|
||||
let secrets = get_all_secrets(
|
||||
&self.pool,
|
||||
&input.namespace,
|
||||
&input.kind,
|
||||
&input.name,
|
||||
&user_key,
|
||||
Some(user_id),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
|
||||
|
||||
let json = serde_json::to_string_pretty(&secrets).unwrap_or_default();
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
}
|
||||
|
||||
#[tool(
|
||||
description = "Add or upsert an entry with metadata and encrypted secret fields. \
|
||||
Requires X-Encryption-Key header. \
|
||||
Meta and secret values use 'key=value', 'key=@file', or 'key:=<json>' format."
|
||||
)]
|
||||
async fn secrets_add(
|
||||
&self,
|
||||
Parameters(input): Parameters<AddInput>,
|
||||
ctx: RequestContext<RoleServer>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
let (user_id, user_key) = Self::require_user_and_key(&ctx)?;
|
||||
|
||||
let tags = input.tags.unwrap_or_default();
|
||||
let meta = input.meta.unwrap_or_default();
|
||||
let secrets = input.secrets.unwrap_or_default();
|
||||
|
||||
let result = svc_add(
|
||||
&self.pool,
|
||||
AddParams {
|
||||
namespace: &input.namespace,
|
||||
kind: &input.kind,
|
||||
name: &input.name,
|
||||
tags: &tags,
|
||||
meta_entries: &meta,
|
||||
secret_entries: &secrets,
|
||||
user_id: Some(user_id),
|
||||
},
|
||||
&user_key,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
|
||||
|
||||
let json = serde_json::to_string_pretty(&result).unwrap_or_default();
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
#[tool(
|
||||
description = "Incrementally update an existing entry. Requires X-Encryption-Key header. \
|
||||
Only the fields you specify are changed; everything else is preserved."
|
||||
)]
|
||||
async fn secrets_update(
|
||||
&self,
|
||||
Parameters(input): Parameters<UpdateInput>,
|
||||
ctx: RequestContext<RoleServer>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
let (user_id, user_key) = Self::require_user_and_key(&ctx)?;
|
||||
|
||||
let add_tags = input.add_tags.unwrap_or_default();
|
||||
let remove_tags = input.remove_tags.unwrap_or_default();
|
||||
let meta = input.meta.unwrap_or_default();
|
||||
let remove_meta = input.remove_meta.unwrap_or_default();
|
||||
let secrets = input.secrets.unwrap_or_default();
|
||||
let remove_secrets = input.remove_secrets.unwrap_or_default();
|
||||
|
||||
let result = svc_update(
|
||||
&self.pool,
|
||||
UpdateParams {
|
||||
namespace: &input.namespace,
|
||||
kind: &input.kind,
|
||||
name: &input.name,
|
||||
add_tags: &add_tags,
|
||||
remove_tags: &remove_tags,
|
||||
meta_entries: &meta,
|
||||
remove_meta: &remove_meta,
|
||||
secret_entries: &secrets,
|
||||
remove_secrets: &remove_secrets,
|
||||
user_id: Some(user_id),
|
||||
},
|
||||
&user_key,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
|
||||
|
||||
let json = serde_json::to_string_pretty(&result).unwrap_or_default();
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
#[tool(
|
||||
description = "Delete one entry (specify namespace+kind+name) or bulk delete all \
|
||||
entries matching namespace+kind. Use dry_run=true to preview."
|
||||
)]
|
||||
async fn secrets_delete(
|
||||
&self,
|
||||
Parameters(input): Parameters<DeleteInput>,
|
||||
ctx: RequestContext<RoleServer>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
let user_id = Self::user_id_from_ctx(&ctx)?;
|
||||
|
||||
let result = svc_delete(
|
||||
&self.pool,
|
||||
DeleteParams {
|
||||
namespace: &input.namespace,
|
||||
kind: input.kind.as_deref(),
|
||||
name: input.name.as_deref(),
|
||||
dry_run: input.dry_run.unwrap_or(false),
|
||||
user_id,
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
|
||||
|
||||
let json = serde_json::to_string_pretty(&result).unwrap_or_default();
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
#[tool(
|
||||
description = "View change history for an entry. Returns a list of versions with \
|
||||
actions and timestamps."
|
||||
)]
|
||||
async fn secrets_history(
|
||||
&self,
|
||||
Parameters(input): Parameters<HistoryInput>,
|
||||
_ctx: RequestContext<RoleServer>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
let result = svc_history(
|
||||
&self.pool,
|
||||
&input.namespace,
|
||||
&input.kind,
|
||||
&input.name,
|
||||
input.limit.unwrap_or(20),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
|
||||
|
||||
let json = serde_json::to_string_pretty(&result).unwrap_or_default();
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
#[tool(
|
||||
description = "Rollback an entry to a previous version. Requires X-Encryption-Key header. \
|
||||
Omit to_version to restore the most recent snapshot."
|
||||
)]
|
||||
async fn secrets_rollback(
|
||||
&self,
|
||||
Parameters(input): Parameters<RollbackInput>,
|
||||
ctx: RequestContext<RoleServer>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
let (user_id, user_key) = Self::require_user_and_key(&ctx)?;
|
||||
|
||||
let result = svc_rollback(
|
||||
&self.pool,
|
||||
&input.namespace,
|
||||
&input.kind,
|
||||
&input.name,
|
||||
input.to_version,
|
||||
&user_key,
|
||||
Some(user_id),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
|
||||
|
||||
let json = serde_json::to_string_pretty(&result).unwrap_or_default();
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
|
||||
#[tool(
|
||||
description = "Export matching entries with decrypted secrets as JSON/TOML/YAML string. \
|
||||
Requires X-Encryption-Key header. Useful for backup or data migration."
|
||||
)]
|
||||
async fn secrets_export(
|
||||
&self,
|
||||
Parameters(input): Parameters<ExportInput>,
|
||||
ctx: RequestContext<RoleServer>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
let (user_id, user_key) = Self::require_user_and_key(&ctx)?;
|
||||
let tags = input.tags.unwrap_or_default();
|
||||
let format = input.format.as_deref().unwrap_or("json");
|
||||
|
||||
let data = svc_export(
|
||||
&self.pool,
|
||||
ExportParams {
|
||||
namespace: input.namespace.as_deref(),
|
||||
kind: input.kind.as_deref(),
|
||||
name: input.name.as_deref(),
|
||||
tags: &tags,
|
||||
query: input.query.as_deref(),
|
||||
no_secrets: false,
|
||||
user_id: Some(user_id),
|
||||
},
|
||||
Some(&user_key),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
|
||||
|
||||
let serialized = format
|
||||
.parse::<secrets_core::models::ExportFormat>()
|
||||
.and_then(|fmt| fmt.serialize(&data))
|
||||
.map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
|
||||
|
||||
Ok(CallToolResult::success(vec![Content::text(serialized)]))
|
||||
}
|
||||
|
||||
#[tool(
|
||||
description = "Preview the environment variable mapping that would be injected when \
|
||||
running a command. Requires X-Encryption-Key header. \
|
||||
Shows variable names and sources, useful for debugging."
|
||||
)]
|
||||
async fn secrets_env_map(
|
||||
&self,
|
||||
Parameters(input): Parameters<EnvMapInput>,
|
||||
ctx: RequestContext<RoleServer>,
|
||||
) -> Result<CallToolResult, rmcp::ErrorData> {
|
||||
let (user_id, user_key) = Self::require_user_and_key(&ctx)?;
|
||||
let tags = input.tags.unwrap_or_default();
|
||||
let only_fields = input.only_fields.unwrap_or_default();
|
||||
|
||||
let env_map = secrets_core::service::env_map::build_env_map(
|
||||
&self.pool,
|
||||
input.namespace.as_deref(),
|
||||
input.kind.as_deref(),
|
||||
input.name.as_deref(),
|
||||
&tags,
|
||||
&only_fields,
|
||||
input.prefix.as_deref().unwrap_or(""),
|
||||
&user_key,
|
||||
Some(user_id),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| rmcp::ErrorData::internal_error(e.to_string(), None))?;
|
||||
|
||||
let json = serde_json::to_string_pretty(&env_map).unwrap_or_default();
|
||||
Ok(CallToolResult::success(vec![Content::text(json)]))
|
||||
}
|
||||
}
|
||||
|
||||
// ── ServerHandler ─────────────────────────────────────────────────────────────
|
||||
|
||||
#[tool_handler]
|
||||
impl ServerHandler for SecretsService {
|
||||
fn get_info(&self) -> InitializeResult {
|
||||
let mut info = InitializeResult::new(ServerCapabilities::builder().enable_tools().build());
|
||||
info.server_info = Implementation::new("secrets-mcp", env!("CARGO_PKG_VERSION"));
|
||||
info.protocol_version = ProtocolVersion::V_2025_03_26;
|
||||
info.instructions = Some(
|
||||
"Manage cross-device secrets and configuration securely. \
|
||||
Data is encrypted with your passphrase-derived key. \
|
||||
Include your 64-char hex key in the X-Encryption-Key header for all read/write operations. \
|
||||
Use secrets_search to discover entries (no key needed), \
|
||||
secrets_get to decrypt secret values, \
|
||||
and secrets_add/secrets_update to write encrypted secrets."
|
||||
.to_string(),
|
||||
);
|
||||
info
|
||||
}
|
||||
}
|
||||
494
crates/secrets-mcp/src/web.rs
Normal file
494
crates/secrets-mcp/src/web.rs
Normal file
@@ -0,0 +1,494 @@
|
||||
use askama::Template;
|
||||
use axum::{
|
||||
Json, Router,
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::{Html, IntoResponse, Redirect, Response},
|
||||
routing::{get, post},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tower_sessions::Session;
|
||||
use uuid::Uuid;
|
||||
|
||||
use secrets_core::crypto::hex;
|
||||
use secrets_core::service::{
|
||||
api_key::{ensure_api_key, regenerate_api_key},
|
||||
user::{
|
||||
OAuthProfile, bind_oauth_account, find_or_create_user, get_user_by_id,
|
||||
unbind_oauth_account, update_user_key_setup,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::AppState;
|
||||
use crate::oauth::{OAuthConfig, OAuthUserInfo, google_auth_url, random_state};
|
||||
|
||||
const SESSION_USER_ID: &str = "user_id";
|
||||
const SESSION_OAUTH_STATE: &str = "oauth_state";
|
||||
const SESSION_OAUTH_BIND_MODE: &str = "oauth_bind_mode";
|
||||
const SESSION_LOGIN_PROVIDER: &str = "login_provider";
|
||||
|
||||
// ── Template types ────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Template)]
|
||||
#[template(path = "login.html")]
|
||||
struct LoginTemplate {
|
||||
has_google: bool,
|
||||
}
|
||||
|
||||
#[derive(Template)]
|
||||
#[template(path = "dashboard.html")]
|
||||
struct DashboardTemplate {
|
||||
user_name: String,
|
||||
user_email: String,
|
||||
has_passphrase: bool,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
// ── App state helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
fn google_cfg(state: &AppState) -> Option<&OAuthConfig> {
|
||||
state.google_config.as_ref()
|
||||
}
|
||||
|
||||
async fn current_user_id(session: &Session) -> Option<Uuid> {
|
||||
session
|
||||
.get::<String>(SESSION_USER_ID)
|
||||
.await
|
||||
.ok()
|
||||
.flatten()
|
||||
.and_then(|s| Uuid::parse_str(&s).ok())
|
||||
}
|
||||
|
||||
// ── Routes ────────────────────────────────────────────────────────────────────
|
||||
|
||||
pub fn web_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/", get(login_page))
|
||||
.route("/auth/google", get(auth_google))
|
||||
.route("/auth/google/callback", get(auth_google_callback))
|
||||
.route("/auth/logout", post(auth_logout))
|
||||
.route("/dashboard", get(dashboard))
|
||||
.route("/account/bind/google", get(account_bind_google))
|
||||
.route(
|
||||
"/account/bind/google/callback",
|
||||
get(account_bind_google_callback),
|
||||
)
|
||||
.route("/account/unbind/{provider}", post(account_unbind))
|
||||
.route("/api/key-salt", get(api_key_salt))
|
||||
.route("/api/key-setup", post(api_key_setup))
|
||||
.route("/api/apikey", get(api_apikey_get))
|
||||
.route("/api/apikey/regenerate", post(api_apikey_regenerate))
|
||||
}
|
||||
|
||||
// ── Login page ────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn login_page(
|
||||
State(state): State<AppState>,
|
||||
session: Session,
|
||||
) -> Result<Response, StatusCode> {
|
||||
if let Some(_uid) = current_user_id(&session).await {
|
||||
return Ok(Redirect::to("/dashboard").into_response());
|
||||
}
|
||||
|
||||
let tmpl = LoginTemplate {
|
||||
has_google: state.google_config.is_some(),
|
||||
};
|
||||
render_template(tmpl)
|
||||
}
|
||||
|
||||
// ── Google OAuth ──────────────────────────────────────────────────────────────
|
||||
|
||||
async fn auth_google(
|
||||
State(state): State<AppState>,
|
||||
session: Session,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let config = google_cfg(&state).ok_or(StatusCode::SERVICE_UNAVAILABLE)?;
|
||||
|
||||
let oauth_state = random_state();
|
||||
session
|
||||
.insert(SESSION_OAUTH_STATE, &oauth_state)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let url = google_auth_url(config, &oauth_state);
|
||||
Ok(Redirect::to(&url).into_response())
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct OAuthCallbackQuery {
|
||||
code: Option<String>,
|
||||
state: Option<String>,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
async fn auth_google_callback(
|
||||
State(state): State<AppState>,
|
||||
session: Session,
|
||||
Query(params): Query<OAuthCallbackQuery>,
|
||||
) -> Result<Response, StatusCode> {
|
||||
handle_oauth_callback(&state, &session, params, "google", |s, cfg, code| {
|
||||
Box::pin(crate::oauth::google::exchange_code(
|
||||
&s.http_client,
|
||||
cfg,
|
||||
code,
|
||||
))
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
// ── Shared OAuth callback handler ─────────────────────────────────────────────
|
||||
|
||||
async fn handle_oauth_callback<F>(
|
||||
state: &AppState,
|
||||
session: &Session,
|
||||
params: OAuthCallbackQuery,
|
||||
provider: &str,
|
||||
exchange_fn: F,
|
||||
) -> Result<Response, StatusCode>
|
||||
where
|
||||
F: for<'a> Fn(
|
||||
&'a AppState,
|
||||
&'a OAuthConfig,
|
||||
&'a str,
|
||||
) -> std::pin::Pin<
|
||||
Box<dyn std::future::Future<Output = anyhow::Result<OAuthUserInfo>> + Send + 'a>,
|
||||
>,
|
||||
{
|
||||
if let Some(err) = params.error {
|
||||
tracing::warn!(provider, error = %err, "OAuth error");
|
||||
return Ok(Redirect::to("/?error=oauth_error").into_response());
|
||||
}
|
||||
|
||||
let Some(code) = params.code else {
|
||||
tracing::warn!(provider, "OAuth callback missing code");
|
||||
return Ok(Redirect::to("/?error=oauth_missing_code").into_response());
|
||||
};
|
||||
let Some(returned_state) = params.state.as_deref() else {
|
||||
tracing::warn!(provider, "OAuth callback missing state");
|
||||
return Ok(Redirect::to("/?error=oauth_missing_state").into_response());
|
||||
};
|
||||
|
||||
let expected_state: Option<String> = session
|
||||
.get(SESSION_OAUTH_STATE)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
if expected_state.as_deref() != Some(returned_state) {
|
||||
tracing::warn!(
|
||||
provider,
|
||||
expected_present = expected_state.is_some(),
|
||||
"OAuth state mismatch (empty session often means SameSite=Strict or server restart)"
|
||||
);
|
||||
return Ok(Redirect::to("/?error=oauth_state").into_response());
|
||||
}
|
||||
session.remove::<String>(SESSION_OAUTH_STATE).await.ok();
|
||||
|
||||
let config = match provider {
|
||||
"google" => state
|
||||
.google_config
|
||||
.as_ref()
|
||||
.ok_or(StatusCode::SERVICE_UNAVAILABLE)?,
|
||||
_ => return Err(StatusCode::BAD_REQUEST),
|
||||
};
|
||||
|
||||
let user_info = exchange_fn(state, config, code.as_str())
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(provider, error = %e, "failed to exchange OAuth code");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
let bind_mode: bool = session
|
||||
.get(SESSION_OAUTH_BIND_MODE)
|
||||
.await
|
||||
.unwrap_or(None)
|
||||
.unwrap_or(false);
|
||||
|
||||
if bind_mode {
|
||||
let user_id = current_user_id(session)
|
||||
.await
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
session.remove::<bool>(SESSION_OAUTH_BIND_MODE).await.ok();
|
||||
|
||||
let profile = OAuthProfile {
|
||||
provider: user_info.provider,
|
||||
provider_id: user_info.provider_id,
|
||||
email: user_info.email,
|
||||
name: user_info.name,
|
||||
avatar_url: user_info.avatar_url,
|
||||
};
|
||||
|
||||
bind_oauth_account(&state.pool, user_id, profile)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "failed to bind OAuth account");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
return Ok(Redirect::to("/dashboard?bound=1").into_response());
|
||||
}
|
||||
|
||||
let profile = OAuthProfile {
|
||||
provider: user_info.provider,
|
||||
provider_id: user_info.provider_id,
|
||||
email: user_info.email,
|
||||
name: user_info.name,
|
||||
avatar_url: user_info.avatar_url,
|
||||
};
|
||||
|
||||
let (user, _is_new) = find_or_create_user(&state.pool, profile)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "failed to find or create user");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
// Ensure the user has an API key (auto-creates on first login).
|
||||
if let Err(e) = ensure_api_key(&state.pool, user.id).await {
|
||||
tracing::warn!(error = %e, "failed to ensure api key for user");
|
||||
}
|
||||
|
||||
session
|
||||
.insert(SESSION_USER_ID, user.id.to_string())
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
session
|
||||
.insert(SESSION_LOGIN_PROVIDER, &provider)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Redirect::to("/dashboard").into_response())
|
||||
}
|
||||
|
||||
// ── Logout ────────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn auth_logout(session: Session) -> impl IntoResponse {
|
||||
session.flush().await.ok();
|
||||
Redirect::to("/")
|
||||
}
|
||||
|
||||
// ── Dashboard ─────────────────────────────────────────────────────────────────
|
||||
|
||||
async fn dashboard(
|
||||
State(state): State<AppState>,
|
||||
session: Session,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let user_id = current_user_id(&session)
|
||||
.await
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let user = get_user_by_id(&state.pool, user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let tmpl = DashboardTemplate {
|
||||
user_name: user.name.clone(),
|
||||
user_email: user.email.clone().unwrap_or_default(),
|
||||
has_passphrase: user.key_salt.is_some(),
|
||||
base_url: state.base_url.clone(),
|
||||
};
|
||||
|
||||
render_template(tmpl)
|
||||
}
|
||||
|
||||
// ── Account bind/unbind ───────────────────────────────────────────────────────
|
||||
|
||||
async fn account_bind_google(
|
||||
State(state): State<AppState>,
|
||||
session: Session,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let _ = current_user_id(&session)
|
||||
.await
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
session
|
||||
.insert(SESSION_OAUTH_BIND_MODE, true)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
let redirect_uri = format!("{}/account/bind/google/callback", state.base_url);
|
||||
let mut cfg = state
|
||||
.google_config
|
||||
.clone()
|
||||
.ok_or(StatusCode::SERVICE_UNAVAILABLE)?;
|
||||
cfg.redirect_uri = redirect_uri;
|
||||
let st = random_state();
|
||||
session.insert(SESSION_OAUTH_STATE, &st).await.ok();
|
||||
|
||||
Ok(Redirect::to(&google_auth_url(&cfg, &st)).into_response())
|
||||
}
|
||||
|
||||
async fn account_bind_google_callback(
|
||||
State(state): State<AppState>,
|
||||
session: Session,
|
||||
Query(params): Query<OAuthCallbackQuery>,
|
||||
) -> Result<Response, StatusCode> {
|
||||
handle_oauth_callback(&state, &session, params, "google", |s, cfg, code| {
|
||||
Box::pin(crate::oauth::google::exchange_code(
|
||||
&s.http_client,
|
||||
cfg,
|
||||
code,
|
||||
))
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn account_unbind(
|
||||
State(state): State<AppState>,
|
||||
Path(provider): Path<String>,
|
||||
session: Session,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let user_id = current_user_id(&session)
|
||||
.await
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let current_login_provider = session
|
||||
.get::<String>(SESSION_LOGIN_PROVIDER)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
unbind_oauth_account(
|
||||
&state.pool,
|
||||
user_id,
|
||||
&provider,
|
||||
current_login_provider.as_deref(),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::warn!(error = %e, "failed to unbind oauth account");
|
||||
StatusCode::BAD_REQUEST
|
||||
})?;
|
||||
|
||||
Ok(Redirect::to("/dashboard?unbound=1").into_response())
|
||||
}
|
||||
|
||||
// ── Passphrase / Key setup API ────────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct KeySaltResponse {
|
||||
has_passphrase: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
salt: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
key_check: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
params: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
async fn api_key_salt(
|
||||
State(state): State<AppState>,
|
||||
session: Session,
|
||||
) -> Result<Json<KeySaltResponse>, StatusCode> {
|
||||
let user_id = current_user_id(&session)
|
||||
.await
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let user = get_user_by_id(&state.pool, user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if user.key_salt.is_none() {
|
||||
return Ok(Json(KeySaltResponse {
|
||||
has_passphrase: false,
|
||||
salt: None,
|
||||
key_check: None,
|
||||
params: None,
|
||||
}));
|
||||
}
|
||||
|
||||
Ok(Json(KeySaltResponse {
|
||||
has_passphrase: true,
|
||||
salt: user.key_salt.as_deref().map(hex::encode_hex),
|
||||
key_check: user.key_check.as_deref().map(hex::encode_hex),
|
||||
params: user.key_params,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct KeySetupRequest {
|
||||
/// Hex-encoded 32-byte random salt
|
||||
salt: String,
|
||||
/// Hex-encoded AES-256-GCM encryption of "secrets-mcp-key-check" with the derived key
|
||||
key_check: String,
|
||||
/// Key derivation parameters, e.g. {"alg":"pbkdf2-sha256","iterations":600000}
|
||||
params: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct KeySetupResponse {
|
||||
ok: bool,
|
||||
}
|
||||
|
||||
async fn api_key_setup(
|
||||
State(state): State<AppState>,
|
||||
session: Session,
|
||||
Json(body): Json<KeySetupRequest>,
|
||||
) -> Result<Json<KeySetupResponse>, StatusCode> {
|
||||
let user_id = current_user_id(&session)
|
||||
.await
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let salt = hex::decode_hex(&body.salt).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
let key_check = hex::decode_hex(&body.key_check).map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
if salt.len() != 32 {
|
||||
return Err(StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
update_user_key_setup(&state.pool, user_id, &salt, &key_check, &body.params)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!(error = %e, "failed to update key setup");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
|
||||
Ok(Json(KeySetupResponse { ok: true }))
|
||||
}
|
||||
|
||||
// ── API Key management ────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ApiKeyResponse {
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
async fn api_apikey_get(
|
||||
State(state): State<AppState>,
|
||||
session: Session,
|
||||
) -> Result<Json<ApiKeyResponse>, StatusCode> {
|
||||
let user_id = current_user_id(&session)
|
||||
.await
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let api_key = ensure_api_key(&state.pool, user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(ApiKeyResponse { api_key }))
|
||||
}
|
||||
|
||||
async fn api_apikey_regenerate(
|
||||
State(state): State<AppState>,
|
||||
session: Session,
|
||||
) -> Result<Json<ApiKeyResponse>, StatusCode> {
|
||||
let user_id = current_user_id(&session)
|
||||
.await
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
let api_key = regenerate_api_key(&state.pool, user_id)
|
||||
.await
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
|
||||
|
||||
Ok(Json(ApiKeyResponse { api_key }))
|
||||
}
|
||||
|
||||
// ── Helper ────────────────────────────────────────────────────────────────────
|
||||
|
||||
fn render_template<T: Template>(tmpl: T) -> Result<Response, StatusCode> {
|
||||
let html = tmpl.render().map_err(|e| {
|
||||
tracing::error!(error = %e, "template render error");
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
})?;
|
||||
Ok(Html(html).into_response())
|
||||
}
|
||||
Reference in New Issue
Block a user