Files
secrets/crates/secrets-mcp/src/web/mod.rs
voson 59084a409d
All checks were successful
Secrets MCP — Build & Release / 检查 / 构建 / 发版 (push) Successful in 6m3s
Secrets MCP — Build & Release / 部署 secrets-mcp (push) Successful in 1m36s
release(secrets-mcp): 0.5.10 — Web 模块化、性能与错误处理
- 拆分 web.rs 为 web/ 子模块;统一 client_ip 提取
- core: user_scope SQL 复用、env_map N+1 消除、FETCH_ALL 上限调整
- entries 列表页并行查询;PgPool 去 Arc;结构化 NotFound 等错误
- CI: SSH 私钥安全写入;crypto/hex 与依赖清理;MCP 输入长度校验
- AGENTS: API Key 明文存储设计说明
2026-04-06 23:41:07 +08:00

307 lines
10 KiB
Rust

use askama::Template;
use axum::{
Router,
http::{HeaderMap, StatusCode, header},
response::{Html, IntoResponse, Redirect, Response},
routing::{get, patch, post},
};
use tower_sessions::Session;
use uuid::Uuid;
use crate::AppState;
use crate::oauth::OAuthConfig;
mod account;
mod assets;
mod audit;
mod auth;
mod entries;
// ── Session keys ──────────────────────────────────────────────────────────────
/// Session key under which the signed-in user's id is stored as a string;
/// `current_user_id` reads it and parses it as a `Uuid`.
const SESSION_USER_ID: &str = "user_id";
/// Session key `"oauth_state"`. Not read in this file — presumably the OAuth
/// `state` value written/checked by the `auth` handlers (confirm there).
const SESSION_OAUTH_STATE: &str = "oauth_state";
/// Session key `"oauth_bind_mode"`. Not read in this file — presumably marks an
/// OAuth round-trip started from the account "bind provider" flow (confirm in `auth`).
const SESSION_OAUTH_BIND_MODE: &str = "oauth_bind_mode";
/// Session key `"login_provider"`. Not read in this file — presumably records
/// which provider the user signed in with (confirm in `auth`/`account`).
const SESSION_LOGIN_PROVIDER: &str = "login_provider";
/// Session key holding an `i64` key version; `require_valid_user` compares it
/// with the user's `key_version` in the database and invalidates the session
/// on mismatch.
const SESSION_KEY_VERSION: &str = "key_version";
// ── Page limits ───────────────────────────────────────────────────────────────
/// Page size (rows per page) for the HTML entries list; capping it avoids
/// loading unbounded rows into memory.
const ENTRIES_PAGE_LIMIT: u32 = 50;
/// Presumably the page size for the audit page handler (not used in this file).
/// NOTE(review): typed `i64`, unlike `ENTRIES_PAGE_LIMIT` (`u32`) — possibly so
/// it can be bound directly as a SQL `LIMIT`; confirm in `audit`.
const AUDIT_PAGE_LIMIT: i64 = 10;
// ── UI language ───────────────────────────────────────────────────────────────
/// UI language resolved for a request (see `request_ui_lang`).
///
/// Derives `Debug`/`PartialEq`/`Eq` so values can be logged and compared
/// directly (e.g. with `assert_eq!`) instead of only via `matches!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum UiLang {
    /// Simplified Chinese; also the fallback when no preference can be resolved.
    ZhCn,
    /// Traditional Chinese.
    ZhTw,
    /// English.
    En,
}
/// Pick the UI language from the request's `Accept-Language` header.
///
/// Each comma-separated entry is mapped with `ui_lang_for_tag`. Entries are
/// ranked by their `q` weight, which defaults to `1.0` when absent or malformed.
/// Among equal weights the earlier-listed entry wins, so the client's stated
/// order is honoured. Entries whose weight is not positive (`q=0` means "not
/// acceptable"), and tags we do not localize for, are skipped.
///
/// Falls back to `UiLang::ZhCn` in three cases:
/// - the header is missing;
/// - the header is not valid visible ASCII;
/// - the header names no supported language.
fn request_ui_lang(headers: &HeaderMap) -> UiLang {
    let Some(raw) = headers
        .get(header::ACCEPT_LANGUAGE)
        .and_then(|v| v.to_str().ok())
    else {
        return UiLang::ZhCn;
    };

    let mut chosen = UiLang::ZhCn;
    // An entry must beat this to be selected, so only positive weights count;
    // NaN also never compares greater and is therefore ignored.
    let mut chosen_weight = 0.0_f32;
    for entry in raw.split(',') {
        let mut parts = entry.split(';');
        let Some(lang) = parts.next().and_then(|tag| ui_lang_for_tag(tag.trim())) else {
            continue;
        };
        // Missing or unparseable `q` parameters mean full weight.
        let weight = parts
            .find_map(|param| {
                let (key, value) = param.split_once('=')?;
                if key.trim().eq_ignore_ascii_case("q") {
                    value.trim().parse::<f32>().ok()
                } else {
                    None
                }
            })
            .unwrap_or(1.0);
        // Strict `>` keeps the earliest-listed entry when weights tie.
        if weight > chosen_weight {
            chosen = lang;
            chosen_weight = weight;
        }
    }
    chosen
}

/// Map one language tag (e.g. `zh-Hant-TW`, `en-US`) to a supported UI
/// language, or `None` when the tag names a language we do not localize for
/// (including the `*` wildcard and empty tags).
///
/// Matching is ASCII case-insensitive, and both `-` and `_` are accepted as
/// subtag separators. For Chinese, an explicit script subtag wins (`Hant` →
/// traditional, `Hans` → simplified). Otherwise a `TW`/`HK` region selects
/// traditional, and anything else simplified.
fn ui_lang_for_tag(tag: &str) -> Option<UiLang> {
    let tag = tag.to_ascii_lowercase();
    let mut subtags = tag.split(['-', '_']);
    match subtags.next()? {
        "en" => Some(UiLang::En),
        "zh" => {
            let (mut hant, mut hans, mut tw_region) = (false, false, false);
            for subtag in subtags {
                match subtag {
                    "hant" => hant = true,
                    "hans" => hans = true,
                    "tw" | "hk" => tw_region = true,
                    _ => {}
                }
            }
            let traditional = hant || (tw_region && !hans);
            Some(if traditional {
                UiLang::ZhTw
            } else {
                UiLang::ZhCn
            })
        }
        _ => None,
    }
}
/// Select the string for `lang` out of the three provided translations.
///
/// Every variant maps to exactly one argument, so callers must always pass all
/// three strings; nothing is looked up or allocated.
fn tr(lang: UiLang, zh_cn: &'static str, zh_tw: &'static str, en: &'static str) -> &'static str {
    match lang {
        UiLang::En => en,
        UiLang::ZhTw => zh_tw,
        UiLang::ZhCn => zh_cn,
    }
}
// ── App state helpers ─────────────────────────────────────────────────────────
/// Borrow the Google OAuth configuration from the app state, if one is set.
fn google_cfg(state: &AppState) -> Option<&OAuthConfig> {
    Option::as_ref(&state.google_config)
}
async fn current_user_id(session: &Session) -> Option<Uuid> {
match session.get::<String>(SESSION_USER_ID).await {
Ok(opt) => match opt {
Some(s) => match Uuid::parse_str(&s) {
Ok(id) => Some(id),
Err(e) => {
tracing::warn!(error = %e, user_id_str = %s, "invalid user_id UUID in session");
None
}
},
None => None,
},
Err(e) => {
tracing::warn!(error = %e, "failed to read user_id from session");
None
}
}
}
/// Load and validate the current user from session and DB.
///
/// Returns the user if the session is valid. Flushes the session and returns
/// `Err(Redirect::to("/login"))` when:
/// - the session has no `user_id`,
/// - the user no longer exists in the database, or
/// - the stored `key_version` does not match the DB value (passphrase changed on
/// another device since this session was created).
async fn require_valid_user(
pool: &sqlx::PgPool,
session: &Session,
context: &str,
) -> Result<secrets_core::models::User, Response> {
let Some(user_id) = current_user_id(session).await else {
return Err(Redirect::to("/login").into_response());
};
let user = match secrets_core::service::user::get_user_by_id(pool, user_id).await {
Err(e) => {
tracing::error!(error = %e, %user_id, context, "failed to load user");
return Err(StatusCode::INTERNAL_SERVER_ERROR.into_response());
}
Ok(None) => {
if let Err(e) = session.flush().await {
tracing::warn!(error = %e, "failed to flush stale session");
}
return Err(Redirect::to("/login").into_response());
}
Ok(Some(u)) => u,
};
let session_kv: Option<i64> = match session.get::<i64>(SESSION_KEY_VERSION).await {
Ok(v) => v,
Err(e) => {
tracing::warn!(error = %e, "failed to read key_version from session; treating as missing");
None
}
};
if let Some(kv) = session_kv
&& kv != user.key_version
{
tracing::info!(%user_id, session_kv = kv, db_kv = user.key_version, "key_version mismatch; invalidating session");
if let Err(e) = session.flush().await {
tracing::warn!(error = %e, "failed to flush outdated session");
}
return Err(Redirect::to("/login").into_response());
}
Ok(user)
}
/// Return the request's `User-Agent` as an owned, trimmed string.
///
/// Yields `None` in three cases:
/// - the header is absent;
/// - the header is not valid visible ASCII;
/// - the header is empty after trimming whitespace.
fn request_user_agent(headers: &HeaderMap) -> Option<String> {
    let value = headers.get(header::USER_AGENT)?;
    let text = value.to_str().ok()?.trim();
    if text.is_empty() {
        return None;
    }
    Some(text.to_owned())
}
/// Compute `(current_page, total_pages, offset)` for a 1-based page request.
///
/// - A `page_size` of 0 is treated as 1.
/// - Negative totals count as zero rows; totals above `u32::MAX` saturate.
/// - There is always at least one page, and `page` is clamped into
///   `1..=total_pages`.
/// - The row offset saturates instead of overflowing.
fn paginate(page: u32, total_count: i64, page_size: u32) -> (u32, u32, u32) {
    let size = if page_size == 0 { 1 } else { page_size };
    let rows = if total_count > 0 {
        u32::try_from(total_count).unwrap_or(u32::MAX)
    } else {
        0
    };
    let pages = match rows.div_ceil(size) {
        0 => 1,
        n => n,
    };
    // `pages >= 1`, so this clamp range is always valid.
    let current = page.clamp(1, pages);
    let offset = size.saturating_mul(current - 1);
    (current, pages, offset)
}
/// Render an askama template into an HTML response.
///
/// # Errors
///
/// On a render failure the error is logged and
/// `StatusCode::INTERNAL_SERVER_ERROR` is returned.
fn render_template<T: Template>(tmpl: T) -> Result<Response, StatusCode> {
    match tmpl.render() {
        Ok(html) => Ok(Html(html).into_response()),
        Err(e) => {
            tracing::error!(error = %e, "template render error");
            Err(StatusCode::INTERNAL_SERVER_ERROR)
        }
    }
}
// ── Routes ────────────────────────────────────────────────────────────────────
/// Build the web router: static/discovery assets, HTML pages, the login and
/// Google OAuth endpoints, and the `/api/...` endpoints.
///
/// Handlers live in the `assets`, `auth`, `account`, `entries` and `audit`
/// submodules. The router is typed `Router<AppState>`, so the caller must
/// supply the state (presumably via `with_state`) before serving.
pub fn web_router() -> Router<AppState> {
Router::new()
// Static assets and crawler / discovery text files.
.route("/robots.txt", get(assets::robots_txt))
.route("/llms.txt", get(assets::llms_txt))
.route("/ai.txt", get(assets::ai_txt))
.route("/static/i18n.js", get(assets::i18n_js))
.route("/favicon.svg", get(assets::favicon_svg))
// `/favicon.ico` requests are permanently redirected to the SVG icon.
.route(
"/favicon.ico",
get(|| async { Redirect::permanent("/favicon.svg") }),
)
// Well-known OAuth protected-resource metadata endpoint.
.route(
"/.well-known/oauth-protected-resource",
get(assets::oauth_protected_resource_metadata),
)
// HTML pages plus the Google sign-in / sign-out flow.
.route("/", get(auth::home_page))
.route("/login", get(auth::login_page))
.route("/auth/google", get(auth::auth_google))
.route("/auth/google/callback", get(auth::auth_google_callback))
.route("/auth/logout", post(auth::auth_logout))
.route("/dashboard", get(account::dashboard))
.route("/entries", get(entries::entries_page))
.route("/audit", get(audit::audit_page))
// Linking / unlinking external login providers on an existing account.
.route("/account/bind/google", get(auth::account_bind_google))
.route("/account/unbind/{provider}", post(auth::account_unbind))
// Account API: key salt/setup/change and API key retrieval/regeneration.
.route("/api/key-salt", get(account::api_key_salt))
.route("/api/key-setup", post(account::api_key_setup))
.route("/api/key-change", post(account::api_key_change))
.route("/api/apikey", get(account::api_apikey_get))
.route(
"/api/apikey/regenerate",
post(account::api_apikey_regenerate),
)
// Entry and secret API.
.route(
"/api/entries/{id}",
patch(entries::api_entry_patch).delete(entries::api_entry_delete),
)
.route(
"/api/entries/{entry_id}/secrets/{secret_id}",
axum::routing::delete(entries::api_entry_secret_unlink),
)
.route(
"/api/entries/{id}/secrets/decrypt",
get(entries::api_entry_secrets_decrypt),
)
.route("/api/secrets/{secret_id}", patch(entries::api_secret_patch))
.route(
"/api/secrets/check-name",
get(entries::api_secret_check_name),
)
}
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use super::*;

    /// Peer address used as the direct connection in the client-IP tests.
    fn loopback_peer() -> SocketAddr {
        SocketAddr::from(([127, 0, 0, 1], 9315))
    }

    /// The client-IP tests assert the default behaviour with `TRUST_PROXY`
    /// unset: forwarded headers are ignored and the direct peer address is
    /// used. They skip themselves when the variable is present.
    fn trust_proxy_set() -> bool {
        std::env::var("TRUST_PROXY").is_ok()
    }

    /// Build a header map carrying only the given `Accept-Language` value.
    fn accept_language(value: &str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(header::ACCEPT_LANGUAGE, value.parse().unwrap());
        headers
    }

    #[test]
    fn client_ip_ignores_forwarded_headers_without_trusted_proxy() {
        if trust_proxy_set() {
            return;
        }
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "203.0.113.10".parse().unwrap());
        let ip = crate::client_ip::extract_client_ip_parts(&headers, loopback_peer());
        assert_eq!(ip, "127.0.0.1");
    }

    #[test]
    fn client_ip_ignores_forwarded_chain_without_trusted_proxy() {
        if trust_proxy_set() {
            return;
        }
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "203.0.113.10, 10.0.0.1".parse().unwrap());
        let ip = crate::client_ip::extract_client_ip_parts(&headers, loopback_peer());
        assert_eq!(ip, "127.0.0.1");
    }

    #[test]
    fn request_ui_lang_defaults_to_zh_cn_without_header() {
        assert!(matches!(request_ui_lang(&HeaderMap::new()), UiLang::ZhCn));
    }

    #[test]
    fn request_ui_lang_prefers_zh_cn_over_en_fallback() {
        let headers = accept_language("zh-CN, en;q=0.5");
        assert!(matches!(request_ui_lang(&headers), UiLang::ZhCn));
    }

    #[test]
    fn request_ui_lang_detects_traditional_chinese_variants() {
        let headers = accept_language("zh-Hant, en;q=0.5");
        assert!(matches!(request_ui_lang(&headers), UiLang::ZhTw));
        let headers = accept_language("zh-HK");
        assert!(matches!(request_ui_lang(&headers), UiLang::ZhTw));
    }

    #[test]
    fn request_ui_lang_keeps_simplified_script_with_hk_region() {
        let headers = accept_language("zh-Hans-HK");
        assert!(matches!(request_ui_lang(&headers), UiLang::ZhCn));
    }

    #[test]
    fn request_ui_lang_detects_english() {
        let headers = accept_language("en-US");
        assert!(matches!(request_ui_lang(&headers), UiLang::En));
    }

    #[test]
    fn request_ui_lang_falls_back_for_unsupported_languages() {
        let headers = accept_language("fr-FR, de;q=0.8");
        assert!(matches!(request_ui_lang(&headers), UiLang::ZhCn));
    }

    #[test]
    fn paginate_clamps_page_before_computing_offset() {
        let (current_page, total_pages, offset) = paginate(100, 12, 10);
        assert_eq!(current_page, 2);
        assert_eq!(total_pages, 2);
        assert_eq!(offset, 10);
    }

    #[test]
    fn paginate_handles_large_page_without_overflow() {
        let (current_page, total_pages, offset) = paginate(u32::MAX, 1, ENTRIES_PAGE_LIMIT);
        assert_eq!(current_page, 1);
        assert_eq!(total_pages, 1);
        assert_eq!(offset, 0);
    }

    #[test]
    fn paginate_saturates_large_total_count() {
        let (_, total_pages, _) = paginate(1, i64::MAX, ENTRIES_PAGE_LIMIT);
        assert_eq!(total_pages, u32::MAX.div_ceil(ENTRIES_PAGE_LIMIT));
    }

    #[test]
    fn paginate_treats_zero_page_size_as_one() {
        assert_eq!(paginate(3, 5, 0), (3, 5, 2));
    }

    #[test]
    fn paginate_handles_empty_and_negative_totals() {
        assert_eq!(paginate(0, 0, 10), (1, 1, 0));
        assert_eq!(paginate(4, -3, 10), (1, 1, 0));
    }
}