- 拆分 web.rs 为 web/ 子模块;统一 client_ip 提取 - core: user_scope SQL 复用、env_map N+1 消除、FETCH_ALL 上限调整 - entries 列表页并行查询;PgPool 去 Arc;结构化 NotFound 等错误 - CI: SSH 私钥安全写入;crypto/hex 与依赖清理;MCP 输入长度校验 - AGENTS: API Key 明文存储设计说明
307 lines
10 KiB
Rust
307 lines
10 KiB
Rust
use askama::Template;
|
|
use axum::{
|
|
Router,
|
|
http::{HeaderMap, StatusCode, header},
|
|
response::{Html, IntoResponse, Redirect, Response},
|
|
routing::{get, patch, post},
|
|
};
|
|
use tower_sessions::Session;
|
|
use uuid::Uuid;
|
|
|
|
use crate::AppState;
|
|
use crate::oauth::OAuthConfig;
|
|
|
|
mod account;
|
|
mod assets;
|
|
mod audit;
|
|
mod auth;
|
|
mod entries;
|
|
|
|
// ── Session keys ──────────────────────────────────────────────────────────────

/// Session entry holding the signed-in user's id as a UUID string
/// (read back and parsed by `current_user_id`).
const SESSION_USER_ID: &str = "user_id";

/// Session entry for the OAuth `state` value (presumably the CSRF token checked on
/// the OAuth callback — confirm in `auth`).
const SESSION_OAUTH_STATE: &str = "oauth_state";

/// Session entry marking whether the OAuth flow is binding a provider to an existing
/// account rather than logging in (NOTE(review): inferred from the name; see `auth`).
const SESSION_OAUTH_BIND_MODE: &str = "oauth_bind_mode";

/// Session entry recording which provider the user logged in with
/// (NOTE(review): inferred from the name; see `auth`).
const SESSION_LOGIN_PROVIDER: &str = "login_provider";

/// Session entry holding the user's `key_version` (an `i64`) at the time the session
/// was established; `require_valid_user` compares it with the database value.
const SESSION_KEY_VERSION: &str = "key_version";
|
|
|
|
// ── Page limits ───────────────────────────────────────────────────────────────

/// Upper bound on rows rendered per page of the HTML entries list, so a single
/// request never loads an unbounded number of rows into memory.
const ENTRIES_PAGE_LIMIT: u32 = 50;

/// Rows per page for the audit list. Typed `i64`, presumably so it can be bound
/// directly as a SQL `LIMIT` value (NOTE(review): confirm in `audit`).
const AUDIT_PAGE_LIMIT: i64 = 10;
|
|
|
|
// ── UI language ───────────────────────────────────────────────────────────────

/// Languages the web UI can be rendered in.
///
/// `Debug` and `PartialEq`/`Eq` are derived so the value can be logged and compared
/// directly (e.g. with `assert_eq!`) instead of only via `matches!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum UiLang {
    /// Simplified Chinese; also the fallback when no supported language is requested.
    ZhCn,
    /// Traditional Chinese.
    ZhTw,
    /// English.
    En,
}
|
|
|
|
fn request_ui_lang(headers: &HeaderMap) -> UiLang {
|
|
let Some(raw) = headers
|
|
.get(header::ACCEPT_LANGUAGE)
|
|
.and_then(|v| v.to_str().ok())
|
|
else {
|
|
return UiLang::ZhCn;
|
|
};
|
|
let lower = raw.to_ascii_lowercase();
|
|
if lower.contains("zh-tw") || lower.contains("zh-hk") || lower.contains("zh-hant") {
|
|
UiLang::ZhTw
|
|
} else if lower.contains("zh") {
|
|
UiLang::ZhCn
|
|
} else if lower.contains("en") {
|
|
UiLang::En
|
|
} else {
|
|
UiLang::ZhCn
|
|
}
|
|
}
|
|
|
|
fn tr(lang: UiLang, zh_cn: &'static str, zh_tw: &'static str, en: &'static str) -> &'static str {
|
|
match lang {
|
|
UiLang::ZhCn => zh_cn,
|
|
UiLang::ZhTw => zh_tw,
|
|
UiLang::En => en,
|
|
}
|
|
}
|
|
|
|
// ── App state helpers ─────────────────────────────────────────────────────────
|
|
|
|
fn google_cfg(state: &AppState) -> Option<&OAuthConfig> {
|
|
state.google_config.as_ref()
|
|
}
|
|
|
|
async fn current_user_id(session: &Session) -> Option<Uuid> {
|
|
match session.get::<String>(SESSION_USER_ID).await {
|
|
Ok(opt) => match opt {
|
|
Some(s) => match Uuid::parse_str(&s) {
|
|
Ok(id) => Some(id),
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, user_id_str = %s, "invalid user_id UUID in session");
|
|
None
|
|
}
|
|
},
|
|
None => None,
|
|
},
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "failed to read user_id from session");
|
|
None
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Load and validate the current user from session and DB.
|
|
///
|
|
/// Returns the user if the session is valid. Flushes the session and returns
|
|
/// `Err(Redirect::to("/login"))` when:
|
|
/// - the session has no `user_id`,
|
|
/// - the user no longer exists in the database, or
|
|
/// - the stored `key_version` does not match the DB value (passphrase changed on
|
|
/// another device since this session was created).
|
|
async fn require_valid_user(
|
|
pool: &sqlx::PgPool,
|
|
session: &Session,
|
|
context: &str,
|
|
) -> Result<secrets_core::models::User, Response> {
|
|
let Some(user_id) = current_user_id(session).await else {
|
|
return Err(Redirect::to("/login").into_response());
|
|
};
|
|
|
|
let user = match secrets_core::service::user::get_user_by_id(pool, user_id).await {
|
|
Err(e) => {
|
|
tracing::error!(error = %e, %user_id, context, "failed to load user");
|
|
return Err(StatusCode::INTERNAL_SERVER_ERROR.into_response());
|
|
}
|
|
Ok(None) => {
|
|
if let Err(e) = session.flush().await {
|
|
tracing::warn!(error = %e, "failed to flush stale session");
|
|
}
|
|
return Err(Redirect::to("/login").into_response());
|
|
}
|
|
Ok(Some(u)) => u,
|
|
};
|
|
|
|
let session_kv: Option<i64> = match session.get::<i64>(SESSION_KEY_VERSION).await {
|
|
Ok(v) => v,
|
|
Err(e) => {
|
|
tracing::warn!(error = %e, "failed to read key_version from session; treating as missing");
|
|
None
|
|
}
|
|
};
|
|
if let Some(kv) = session_kv
|
|
&& kv != user.key_version
|
|
{
|
|
tracing::info!(%user_id, session_kv = kv, db_kv = user.key_version, "key_version mismatch; invalidating session");
|
|
if let Err(e) = session.flush().await {
|
|
tracing::warn!(error = %e, "failed to flush outdated session");
|
|
}
|
|
return Err(Redirect::to("/login").into_response());
|
|
}
|
|
|
|
Ok(user)
|
|
}
|
|
|
|
fn request_user_agent(headers: &HeaderMap) -> Option<String> {
|
|
headers
|
|
.get(header::USER_AGENT)
|
|
.and_then(|value| value.to_str().ok())
|
|
.map(str::trim)
|
|
.filter(|value| !value.is_empty())
|
|
.map(ToOwned::to_owned)
|
|
}
|
|
|
|
/// Computes `(current_page, total_pages, offset)` for a 1-based page request.
///
/// - `page_size` of 0 is treated as 1.
/// - `total_count` is clamped into `0..=u32::MAX`.
/// - There is always at least one page, even when the result set is empty.
/// - The requested page is clamped into `1..=total_pages` *before* the offset is
///   derived, so an out-of-range page never produces an out-of-range offset.
/// - The offset saturates instead of overflowing.
fn paginate(page: u32, total_count: i64, page_size: u32) -> (u32, u32, u32) {
    let per_page = std::cmp::max(page_size, 1);
    let total = if total_count <= 0 {
        0
    } else {
        u32::try_from(total_count).unwrap_or(u32::MAX)
    };
    let pages = match total.div_ceil(per_page) {
        0 => 1,
        n => n,
    };
    // `pages >= 1`, so the clamp bounds are always ordered.
    let current = page.clamp(1, pages);
    let skip = (current - 1).saturating_mul(per_page);
    (current, pages, skip)
}
|
|
|
|
fn render_template<T: Template>(tmpl: T) -> Result<Response, StatusCode> {
|
|
let html = tmpl.render().map_err(|e| {
|
|
tracing::error!(error = %e, "template render error");
|
|
StatusCode::INTERNAL_SERVER_ERROR
|
|
})?;
|
|
Ok(Html(html).into_response())
|
|
}
|
|
|
|
// ── Routes ────────────────────────────────────────────────────────────────────
|
|
|
|
pub fn web_router() -> Router<AppState> {
|
|
Router::new()
|
|
.route("/robots.txt", get(assets::robots_txt))
|
|
.route("/llms.txt", get(assets::llms_txt))
|
|
.route("/ai.txt", get(assets::ai_txt))
|
|
.route("/static/i18n.js", get(assets::i18n_js))
|
|
.route("/favicon.svg", get(assets::favicon_svg))
|
|
.route(
|
|
"/favicon.ico",
|
|
get(|| async { Redirect::permanent("/favicon.svg") }),
|
|
)
|
|
.route(
|
|
"/.well-known/oauth-protected-resource",
|
|
get(assets::oauth_protected_resource_metadata),
|
|
)
|
|
.route("/", get(auth::home_page))
|
|
.route("/login", get(auth::login_page))
|
|
.route("/auth/google", get(auth::auth_google))
|
|
.route("/auth/google/callback", get(auth::auth_google_callback))
|
|
.route("/auth/logout", post(auth::auth_logout))
|
|
.route("/dashboard", get(account::dashboard))
|
|
.route("/entries", get(entries::entries_page))
|
|
.route("/audit", get(audit::audit_page))
|
|
.route("/account/bind/google", get(auth::account_bind_google))
|
|
.route("/account/unbind/{provider}", post(auth::account_unbind))
|
|
.route("/api/key-salt", get(account::api_key_salt))
|
|
.route("/api/key-setup", post(account::api_key_setup))
|
|
.route("/api/key-change", post(account::api_key_change))
|
|
.route("/api/apikey", get(account::api_apikey_get))
|
|
.route(
|
|
"/api/apikey/regenerate",
|
|
post(account::api_apikey_regenerate),
|
|
)
|
|
.route(
|
|
"/api/entries/{id}",
|
|
patch(entries::api_entry_patch).delete(entries::api_entry_delete),
|
|
)
|
|
.route(
|
|
"/api/entries/{entry_id}/secrets/{secret_id}",
|
|
axum::routing::delete(entries::api_entry_secret_unlink),
|
|
)
|
|
.route(
|
|
"/api/entries/{id}/secrets/decrypt",
|
|
get(entries::api_entry_secrets_decrypt),
|
|
)
|
|
.route("/api/secrets/{secret_id}", patch(entries::api_secret_patch))
|
|
.route(
|
|
"/api/secrets/check-name",
|
|
get(entries::api_secret_check_name),
|
|
)
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use super::*;

    /// The client-IP tests assume `TRUST_PROXY` is unset (the default), in which case
    /// forwarded headers are not trusted. They return early when it is set so a
    /// developer's environment cannot make them fail spuriously.
    fn trust_proxy_configured() -> bool {
        std::env::var("TRUST_PROXY").is_ok()
    }

    #[test]
    fn client_ip_ignores_forwarded_headers_without_trusted_proxy() {
        if trust_proxy_configured() {
            return;
        }
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "203.0.113.10".parse().unwrap());

        let ip = crate::client_ip::extract_client_ip_parts(
            &headers,
            SocketAddr::from(([127, 0, 0, 1], 9315)),
        );

        assert_eq!(ip, "127.0.0.1");
    }

    /// A multi-hop `X-Forwarded-For` chain is also ignored without a trusted proxy:
    /// the direct connection address is used.
    #[test]
    fn client_ip_ignores_forwarded_chain_without_trusted_proxy() {
        if trust_proxy_configured() {
            return;
        }
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "203.0.113.10, 10.0.0.1".parse().unwrap());

        let ip = crate::client_ip::extract_client_ip_parts(
            &headers,
            SocketAddr::from(([127, 0, 0, 1], 9315)),
        );

        assert_eq!(ip, "127.0.0.1");
    }

    #[test]
    fn request_ui_lang_prefers_zh_cn_over_en_fallback() {
        let mut headers = HeaderMap::new();
        headers.insert(header::ACCEPT_LANGUAGE, "zh-CN, en;q=0.5".parse().unwrap());

        assert!(matches!(request_ui_lang(&headers), UiLang::ZhCn));
    }

    #[test]
    fn request_ui_lang_detects_traditional_chinese_variants() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::ACCEPT_LANGUAGE,
            "zh-Hant, en;q=0.5".parse().unwrap(),
        );

        assert!(matches!(request_ui_lang(&headers), UiLang::ZhTw));
    }

    #[test]
    fn request_ui_lang_defaults_to_zh_cn() {
        assert!(matches!(request_ui_lang(&HeaderMap::new()), UiLang::ZhCn));

        let mut headers = HeaderMap::new();
        headers.insert(header::ACCEPT_LANGUAGE, "fr-FR".parse().unwrap());
        assert!(matches!(request_ui_lang(&headers), UiLang::ZhCn));
    }

    #[test]
    fn request_ui_lang_detects_english() {
        let mut headers = HeaderMap::new();
        headers.insert(header::ACCEPT_LANGUAGE, "en-US".parse().unwrap());

        assert!(matches!(request_ui_lang(&headers), UiLang::En));
    }

    #[test]
    fn tr_selects_string_for_language() {
        assert_eq!(tr(UiLang::ZhCn, "cn", "tw", "en"), "cn");
        assert_eq!(tr(UiLang::ZhTw, "cn", "tw", "en"), "tw");
        assert_eq!(tr(UiLang::En, "cn", "tw", "en"), "en");
    }

    #[test]
    fn request_user_agent_trims_and_rejects_blank() {
        let mut headers = HeaderMap::new();
        assert_eq!(request_user_agent(&headers), None);

        headers.insert(header::USER_AGENT, "   ".parse().unwrap());
        assert_eq!(request_user_agent(&headers), None);

        headers.insert(header::USER_AGENT, "  Mozilla/5.0  ".parse().unwrap());
        assert_eq!(request_user_agent(&headers).as_deref(), Some("Mozilla/5.0"));
    }

    #[test]
    fn paginate_clamps_page_before_computing_offset() {
        let (current_page, total_pages, offset) = paginate(100, 12, 10);

        assert_eq!(current_page, 2);
        assert_eq!(total_pages, 2);
        assert_eq!(offset, 10);
    }

    #[test]
    fn paginate_handles_large_page_without_overflow() {
        let (current_page, total_pages, offset) = paginate(u32::MAX, 1, ENTRIES_PAGE_LIMIT);

        assert_eq!(current_page, 1);
        assert_eq!(total_pages, 1);
        assert_eq!(offset, 0);
    }

    #[test]
    fn paginate_saturates_large_total_count() {
        let (_, total_pages, _) = paginate(1, i64::MAX, ENTRIES_PAGE_LIMIT);

        assert_eq!(total_pages, u32::MAX.div_ceil(ENTRIES_PAGE_LIMIT));
    }

    #[test]
    fn paginate_handles_zero_page_size_and_negative_total() {
        assert_eq!(paginate(0, -5, 0), (1, 1, 0));
    }
}
|