diff --git a/Cargo.lock b/Cargo.lock index a98fa31..3000cf6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2065,7 +2065,7 @@ dependencies = [ [[package]] name = "secrets-mcp" -version = "0.5.19" +version = "0.5.20" dependencies = [ "anyhow", "askama", diff --git a/crates/secrets-core/src/models.rs b/crates/secrets-core/src/models.rs index 8687d41..cab5e29 100644 --- a/crates/secrets-core/src/models.rs +++ b/crates/secrets-core/src/models.rs @@ -184,6 +184,9 @@ pub struct ExportEntry { /// Decrypted secret fields. None means no secrets in this export (--no-secrets). #[serde(default, skip_serializing_if = "Option::is_none")] pub secrets: Option>, + /// Per-secret types (`text`, `password`, `key`, …). Omitted in legacy exports; importers default to `"text"`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub secret_types: Option>, } // ── Multi-user models ────────────────────────────────────────────────────────── @@ -311,3 +314,44 @@ pub fn toml_to_json_value(v: &toml::Value) -> Value { } } } + +#[cfg(test)] +mod export_entry_tests { + use super::*; + use std::collections::BTreeMap; + + #[test] + fn export_entry_roundtrip_includes_secret_types() { + let mut secrets = BTreeMap::new(); + secrets.insert("k".to_string(), serde_json::json!("v")); + let mut types = BTreeMap::new(); + types.insert("k".to_string(), "password".to_string()); + let e = ExportEntry { + name: "n".to_string(), + folder: "f".to_string(), + entry_type: "t".to_string(), + notes: "".to_string(), + tags: vec![], + metadata: serde_json::json!({}), + secrets: Some(secrets), + secret_types: Some(types), + }; + let json = serde_json::to_string(&e).unwrap(); + let back: ExportEntry = serde_json::from_str(&json).unwrap(); + assert_eq!( + back.secret_types + .as_ref() + .unwrap() + .get("k") + .map(String::as_str), + Some("password") + ); + } + + #[test] + fn export_entry_legacy_json_without_secret_types_deserializes() { + let json = r#"{"name":"a","folder":"","type":"","notes":"","tags":[],"metadata":{},"secrets":{"x":"y"}}"#; + let e: ExportEntry = serde_json::from_str(json).unwrap(); + assert!(e.secret_types.is_none()); + } +} diff --git a/crates/secrets-core/src/service/add.rs b/crates/secrets-core/src/service/add.rs index 751c653..6faa0f9 100644 --- a/crates/secrets-core/src/service/add.rs +++ b/crates/secrets-core/src/service/add.rs @@ -161,6 +161,7 @@ pub fn flatten_json_fields(prefix: &str, value: &Value) -> Vec<(String, Value)> #[derive(Debug, serde::Serialize)] pub struct AddResult { + pub entry_id: Uuid, pub name: String, pub folder: String, #[serde(rename = "type")] @@ -477,6 +478,7 @@ pub async fn run(pool: &PgPool, params: AddParams<'_>, master_key: &[u8; 32]) -> tx.commit().await?; Ok(AddResult { + entry_id, name: params.name.to_string(), folder: params.folder.to_string(), entry_type: entry_type.to_string(), diff --git a/crates/secrets-core/src/service/api_key.rs b/crates/secrets-core/src/service/api_key.rs index 6bd25e9..1d7d635 100644 --- a/crates/secrets-core/src/service/api_key.rs +++ b/crates/secrets-core/src/service/api_key.rs @@ -47,11 +47,14 @@ pub async fn ensure_api_key(pool: &PgPool, user_id: Uuid) -> Result { /// Generate a fresh API key for the user, replacing the old one. pub async fn regenerate_api_key(pool: &PgPool, user_id: Uuid) -> Result { let new_key = generate_api_key(); - sqlx::query("UPDATE users SET api_key = $1 WHERE id = $2") + let res = sqlx::query("UPDATE users SET api_key = $1 WHERE id = $2") .bind(&new_key) .bind(user_id) .execute(pool) .await?; + if res.rows_affected() == 0 { + return Err(AppError::NotFoundUser.into()); + } Ok(new_key) } diff --git a/crates/secrets-core/src/service/env_map.rs b/crates/secrets-core/src/service/env_map.rs index 04f3030..f8d9399 100644 --- a/crates/secrets-core/src/service/env_map.rs +++ b/crates/secrets-core/src/service/env_map.rs @@ -45,18 +45,27 @@ pub async fn build_env_map( for f in fields { let decrypted = crypto::decrypt_json(master_key, &f.encrypted)?; - let key = format!( - "{}_{}", - effective_prefix, - f.name.to_uppercase().replace(['-', '.'], "_") - ); - combined.insert(key, json_to_env_string(&decrypted)); + let seg = secret_name_to_env_segment(&f.name); + let key = format!("{}_{}", effective_prefix, seg); + if let Some(_old) = combined.insert(key.clone(), json_to_env_string(&decrypted)) { + anyhow::bail!( + "environment variable name collision after normalization: '{}' (secret '{}')", + key, + f.name + ); + } } } Ok(combined) } +/// Map a secret field name to an env key segment: `.` → `__`, `-` → `_`, then uppercase. +/// Avoids collisions between e.g. `db.password` and `db_password`. +fn secret_name_to_env_segment(name: &str) -> String { + name.replace('.', "__").replace('-', "_").to_uppercase() +} + fn env_prefix(entry: &crate::models::Entry, prefix: &str) -> String { let name_part = entry.name.to_uppercase().replace(['-', '.', ' '], "_"); if prefix.is_empty() { @@ -75,3 +84,14 @@ fn json_to_env_string(v: &Value) -> String { other => other.to_string(), } } + +#[cfg(test)] +mod tests { + use super::secret_name_to_env_segment; + + #[test] + fn secret_name_env_segment_disambiguates_dot_from_underscore() { + assert_eq!(secret_name_to_env_segment("db.password"), "DB__PASSWORD"); + assert_eq!(secret_name_to_env_segment("db_password"), "DB_PASSWORD"); + } +} diff --git a/crates/secrets-core/src/service/export.rs b/crates/secrets-core/src/service/export.rs index ec6dd14..463f15b 100644 --- a/crates/secrets-core/src/service/export.rs +++ b/crates/secrets-core/src/service/export.rs @@ -44,21 +44,23 @@ pub async fn export( let mut export_entries: Vec = Vec::with_capacity(entries.len()); for entry in &entries { - let secrets = if params.no_secrets { - None + let (secrets, secret_types) = if params.no_secrets { + (None, None) } else { let fields = secrets_map.get(&entry.id).map(Vec::as_slice).unwrap_or(&[]); if fields.is_empty() { - Some(BTreeMap::new()) + (Some(BTreeMap::new()), Some(BTreeMap::new())) } else { let mk = master_key .ok_or_else(|| anyhow::anyhow!("master key required to decrypt secrets"))?; let mut map = BTreeMap::new(); + let mut type_map = BTreeMap::new(); for f in fields { let decrypted = crypto::decrypt_json(mk, &f.encrypted)?; map.insert(f.name.clone(), decrypted); + type_map.insert(f.name.clone(), f.secret_type.clone()); } - Some(map) + (Some(map), Some(type_map)) } }; @@ -70,6 +72,7 @@ pub async fn export( tags: entry.tags.clone(), metadata: entry.metadata.clone(), secrets, + secret_types, }); } diff --git a/crates/secrets-core/src/service/import.rs b/crates/secrets-core/src/service/import.rs index bc4b624..83b8e1f 100644 --- a/crates/secrets-core/src/service/import.rs +++ b/crates/secrets-core/src/service/import.rs @@ -1,5 +1,6 @@ use anyhow::Result; use sqlx::PgPool; +use std::collections::HashMap; use uuid::Uuid; use crate::models::ExportFormat; @@ -80,6 +81,11 @@ pub async fn run( let secret_entries = build_secret_entries(entry.secrets.as_ref()); let meta_entries = build_meta_entries(&entry.metadata); + let secret_types_map: HashMap = entry + .secret_types + .as_ref() + .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect()) + .unwrap_or_default(); match add_run( pool, @@ -91,7 +97,7 @@ pub async fn run( tags: &entry.tags, meta_entries: &meta_entries, secret_entries: &secret_entries, - secret_types: &Default::default(), + secret_types: &secret_types_map, link_secret_names: &[], user_id: params.user_id, }, diff --git a/crates/secrets-core/src/service/rollback.rs b/crates/secrets-core/src/service/rollback.rs index 084d5ea..33168b6 100644 --- a/crates/secrets-core/src/service/rollback.rs +++ b/crates/secrets-core/src/service/rollback.rs @@ -30,58 +30,61 @@ pub async fn run( folder: String, #[sqlx(rename = "type")] entry_type: String, + name: String, version: i64, action: String, tags: Vec, metadata: Value, } - let live_entry: Option = if let Some(uid) = user_id { + let mut tx = pool.begin().await?; + + let live: Option = if let Some(uid) = user_id { sqlx::query_as( "SELECT id, version, folder, type, name, tags, metadata, notes, deleted_at FROM entries \ - WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL", + WHERE id = $1 AND user_id = $2 AND deleted_at IS NULL FOR UPDATE", ) .bind(entry_id) .bind(uid) - .fetch_optional(pool) + .fetch_optional(&mut *tx) .await? } else { sqlx::query_as( "SELECT id, version, folder, type, name, tags, metadata, notes, deleted_at FROM entries \ - WHERE id = $1 AND user_id IS NULL AND deleted_at IS NULL", + WHERE id = $1 AND user_id IS NULL AND deleted_at IS NULL FOR UPDATE", ) .bind(entry_id) - .fetch_optional(pool) + .fetch_optional(&mut *tx) .await? }; - let live_entry = live_entry.ok_or(AppError::NotFoundEntry)?; + let lr = live.ok_or(AppError::NotFoundEntry)?; let snap: Option = if let Some(ver) = to_version { sqlx::query_as( - "SELECT folder, type, version, action, tags, metadata \ + "SELECT folder, type, name, version, action, tags, metadata \ FROM entries_history \ WHERE entry_id = $1 AND version = $2 ORDER BY id ASC LIMIT 1", ) .bind(entry_id) .bind(ver) - .fetch_optional(pool) + .fetch_optional(&mut *tx) .await? } else { sqlx::query_as( - "SELECT folder, type, version, action, tags, metadata \ + "SELECT folder, type, name, version, action, tags, metadata \ FROM entries_history \ WHERE entry_id = $1 ORDER BY id DESC LIMIT 1", ) .bind(entry_id) - .fetch_optional(pool) + .fetch_optional(&mut *tx) .await? }; let snap = snap.ok_or_else(|| { anyhow::anyhow!( "No history found for entry '{}'{}.", - live_entry.name, + lr.name, to_version .map(|v| format!(" at version {}", v)) .unwrap_or_default() @@ -91,17 +94,7 @@ pub async fn run( let snap_secret_snapshot = db::entry_secret_snapshot_from_metadata(&snap.metadata); let snap_metadata = db::strip_secret_snapshot_from_metadata(&snap.metadata); - let mut tx = pool.begin().await?; - - let live: Option = sqlx::query_as( - "SELECT id, version, folder, type, name, tags, metadata, notes, deleted_at FROM entries \ - WHERE id = $1 AND deleted_at IS NULL FOR UPDATE", - ) - .bind(entry_id) - .fetch_optional(&mut *tx) - .await?; - - let live_entry_id = if let Some(ref lr) = live { + let live_entry_id = { let history_metadata = match db::metadata_with_secret_snapshot(&mut tx, lr.id, &lr.metadata).await { Ok(v) => v, @@ -168,8 +161,8 @@ pub async fn run( ) .bind(&snap.folder) .bind(&snap.entry_type) - .bind(&live_entry.name) - .bind(&live_entry.notes) + .bind(&snap.name) + .bind(&lr.notes) .bind(&snap.tags) .bind(&snap_metadata) .bind(lr.id) @@ -177,8 +170,6 @@ pub async fn run( .await?; lr.id - } else { - return Err(AppError::NotFoundEntry.into()); }; if let Some(secret_snapshot) = snap_secret_snapshot { @@ -191,7 +182,7 @@ pub async fn run( "rollback", &snap.folder, &snap.entry_type, - &live_entry.name, + &snap.name, serde_json::json!({ "entry_id": entry_id, "restored_version": snap.version, @@ -203,7 +194,7 @@ pub async fn run( tx.commit().await?; Ok(RollbackResult { - name: live_entry.name, + name: snap.name, folder: snap.folder, entry_type: snap.entry_type, restored_version: snap.version, diff --git a/crates/secrets-mcp/Cargo.toml b/crates/secrets-mcp/Cargo.toml index a247dc7..ee90cd8 100644 --- a/crates/secrets-mcp/Cargo.toml +++ b/crates/secrets-mcp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "secrets-mcp" -version = "0.5.19" +version = "0.5.20" edition.workspace = true [[bin]] diff --git a/crates/secrets-mcp/src/tools.rs b/crates/secrets-mcp/src/tools.rs index 84e7cd3..b8c8327 100644 --- a/crates/secrets-mcp/src/tools.rs +++ b/crates/secrets-mcp/src/tools.rs @@ -345,15 +345,6 @@ impl SecretsService { Self::extract_enc_key(ctx) } - /// Require both user_id and encryption key (header only, no arg fallback). - fn require_user_and_key( - ctx: &RequestContext, - ) -> Result<(Uuid, [u8; 32]), rmcp::ErrorData> { - let user_id = Self::require_user_id(ctx)?; - let key = Self::extract_enc_key(ctx)?; - Ok((user_id, key)) - } - /// Require both user_id and encryption key, preferring an explicit argument /// value over the X-Encryption-Key header. fn require_user_and_key_or_arg( @@ -801,10 +792,7 @@ impl SecretsService { let total_count = secrets_core::service::search::count_entries(&self.pool, &count_params) .await - .inspect_err( - |e| tracing::warn!(tool = "secrets_find", error = %e, "count_entries failed"), - ) - .unwrap_or(0); + .map_err(|e| mcp_err_internal_logged("secrets_find", Some(user_id), e))?; let relation_map = get_relations_for_entries( &self.pool, &result @@ -1135,11 +1123,8 @@ impl SecretsService { .await .map_err(|e| mcp_err_from_anyhow("secrets_add", Some(user_id), e))?; - let created_entry = resolve_entry(&self.pool, &input.name, Some(folder), Some(user_id)) - .await - .map_err(|e| mcp_err_internal_logged("secrets_add", Some(user_id), e))?; for parent_id in parent_ids { - add_parent_relation(&self.pool, parent_id, created_entry.id, Some(user_id)) + add_parent_relation(&self.pool, parent_id, result.entry_id, Some(user_id)) .await .map_err(|e| mcp_err_from_anyhow("secrets_add", Some(user_id), e))?; } @@ -1420,7 +1405,7 @@ impl SecretsService { } #[tool( - description = "Rollback an entry to a previous version. Requires X-Encryption-Key header. \ + description = "Rollback an entry to a previous version. Requires Bearer API key only (no encryption key). \ Omit to_version to restore the most recent snapshot. \ Optionally pass 'id' (from secrets_find) to target directly.", annotations(title = "Rollback Secret Entry", destructive_hint = true) @@ -1431,7 +1416,7 @@ impl SecretsService { ctx: RequestContext, ) -> Result { let t = Instant::now(); - let (user_id, _user_key) = Self::require_user_and_key(&ctx)?; + let user_id = Self::require_user_id(&ctx)?; tracing::info!( tool = "secrets_rollback", ?user_id, diff --git a/crates/secrets-mcp/src/web/account.rs b/crates/secrets-mcp/src/web/account.rs index ab33f02..a7106fb 100644 --- a/crates/secrets-mcp/src/web/account.rs +++ b/crates/secrets-mcp/src/web/account.rs @@ -11,7 +11,7 @@ use secrets_core::service::{ use crate::AppState; -use super::{SESSION_KEY_VERSION, current_user_id, render_template, require_valid_user}; +use super::{SESSION_KEY_VERSION, load_session_user_strict, render_template, require_valid_user}; #[derive(Template)] #[template(path = "dashboard.html")] @@ -92,17 +92,11 @@ pub(super) async fn api_key_salt( State(state): State, session: Session, ) -> Result, StatusCode> { - let user_id = current_user_id(&session) - .await - .ok_or(StatusCode::UNAUTHORIZED)?; - - let user = get_user_by_id(&state.pool, user_id) - .await - .map_err(|e| { - tracing::error!(error = %e, %user_id, "failed to load user for key-salt API"); - StatusCode::INTERNAL_SERVER_ERROR - })? - .ok_or(StatusCode::UNAUTHORIZED)?; + let user = match load_session_user_strict(&state.pool, &session).await { + Ok(Some(u)) => u, + Ok(None) => return Err(StatusCode::UNAUTHORIZED), + Err(()) => return Err(StatusCode::INTERNAL_SERVER_ERROR), + }; if user.key_salt.is_none() { return Ok(Json(KeySaltResponse { @@ -126,19 +120,14 @@ pub(super) async fn api_key_setup( session: Session, Json(body): Json, ) -> Result, StatusCode> { - let user_id = current_user_id(&session) - .await - .ok_or(StatusCode::UNAUTHORIZED)?; + let user = match load_session_user_strict(&state.pool, &session).await { + Ok(Some(u)) => u, + Ok(None) => return Err(StatusCode::UNAUTHORIZED), + Err(()) => return Err(StatusCode::INTERNAL_SERVER_ERROR), + }; + let user_id = user.id; // Guard: if a passphrase is already configured, reject and direct to /api/key-change - let user = get_user_by_id(&state.pool, user_id) - .await - .map_err(|e| { - tracing::error!(error = %e, %user_id, "failed to load user for key-setup guard"); - StatusCode::INTERNAL_SERVER_ERROR - })? - .ok_or(StatusCode::UNAUTHORIZED)?; - if user.key_salt.is_some() { tracing::warn!(%user_id, "key-setup called but passphrase already configured; use /api/key-change"); return Err(StatusCode::CONFLICT); @@ -175,17 +164,12 @@ pub(super) async fn api_key_change( session: Session, Json(body): Json, ) -> Result, StatusCode> { - let user_id = current_user_id(&session) - .await - .ok_or(StatusCode::UNAUTHORIZED)?; - - let user = get_user_by_id(&state.pool, user_id) - .await - .map_err(|e| { - tracing::error!(error = %e, %user_id, "failed to load user for key-change"); - StatusCode::INTERNAL_SERVER_ERROR - })? - .ok_or(StatusCode::UNAUTHORIZED)?; + let user = match load_session_user_strict(&state.pool, &session).await { + Ok(Some(u)) => u, + Ok(None) => return Err(StatusCode::UNAUTHORIZED), + Err(()) => return Err(StatusCode::INTERNAL_SERVER_ERROR), + }; + let user_id = user.id; // Must have an existing passphrase to change let existing_key_check = user.key_check.ok_or_else(|| { @@ -276,9 +260,12 @@ pub(super) async fn api_apikey_get( State(state): State, session: Session, ) -> Result, StatusCode> { - let user_id = current_user_id(&session) - .await - .ok_or(StatusCode::UNAUTHORIZED)?; + let user = match load_session_user_strict(&state.pool, &session).await { + Ok(Some(u)) => u, + Ok(None) => return Err(StatusCode::UNAUTHORIZED), + Err(()) => return Err(StatusCode::INTERNAL_SERVER_ERROR), + }; + let user_id = user.id; let api_key = ensure_api_key(&state.pool, user_id).await.map_err(|e| { tracing::error!(error = %e, %user_id, "ensure_api_key failed"); @@ -292,9 +279,12 @@ pub(super) async fn api_apikey_regenerate( State(state): State, session: Session, ) -> Result, StatusCode> { - let user_id = current_user_id(&session) - .await - .ok_or(StatusCode::UNAUTHORIZED)?; + let user = match load_session_user_strict(&state.pool, &session).await { + Ok(Some(u)) => u, + Ok(None) => return Err(StatusCode::UNAUTHORIZED), + Err(()) => return Err(StatusCode::INTERNAL_SERVER_ERROR), + }; + let user_id = user.id; let api_key = regenerate_api_key(&state.pool, user_id) .await diff --git a/crates/secrets-mcp/src/web/entries.rs b/crates/secrets-mcp/src/web/entries.rs index ec5c247..f1654c7 100644 --- a/crates/secrets-mcp/src/web/entries.rs +++ b/crates/secrets-mcp/src/web/entries.rs @@ -25,8 +25,8 @@ use secrets_core::service::{ use crate::AppState; use super::{ - ENTRIES_PAGE_LIMIT, UiLang, current_user_id, paginate, render_template, request_ui_lang, - require_valid_user, tr, + ENTRIES_PAGE_LIMIT, UiLang, paginate, render_template, request_ui_lang, require_valid_user, + require_valid_user_json, tr, }; // ── Template types ──────────────────────────────────────────────────────────── @@ -616,10 +616,8 @@ pub(super) async fn api_entry_patch( Json(body): Json, ) -> Result, EntryApiError> { let lang = request_ui_lang(&headers); - let user_id = current_user_id(&session).await.ok_or(( - StatusCode::UNAUTHORIZED, - Json(json!({ "error": tr(lang, "未登录", "尚未登入", "Not logged in") })), - ))?; + let user = require_valid_user_json(&state.pool, &session, lang).await?; + let user_id = user.id; let folder = body.folder.trim(); let entry_type = body.entry_type.trim(); @@ -635,6 +633,39 @@ pub(super) async fn api_entry_patch( )); } + if folder.chars().count() > crate::validation::MAX_FOLDER_LENGTH { + return Err(( + StatusCode::BAD_REQUEST, + Json( + json!({ "error": tr(lang, "folder 长度不能超过 128 个字符", "folder 長度不能超過 128 個字元", "folder must be at most 128 characters") }), + ), + )); + } + if entry_type.chars().count() > crate::validation::MAX_ENTRY_TYPE_LENGTH { + return Err(( + StatusCode::BAD_REQUEST, + Json( + json!({ "error": tr(lang, "type 长度不能超过 64 个字符", "type 長度不能超過 64 個字元", "type must be at most 64 characters") }), + ), + )); + } + if name.chars().count() > crate::validation::MAX_NAME_LENGTH { + return Err(( + StatusCode::BAD_REQUEST, + Json( + json!({ "error": tr(lang, "name 长度不能超过 256 个字符", "name 長度不能超過 256 個字元", "name must be at most 256 characters") }), + ), + )); + } + if notes.chars().count() > crate::validation::MAX_NOTES_LENGTH { + return Err(( + StatusCode::BAD_REQUEST, + Json( + json!({ "error": tr(lang, "notes 长度不能超过 10000 个字符", "notes 長度不能超過 10000 個字元", "notes must be at most 10000 characters") }), + ), + )); + } + let tags: Vec = body .tags .into_iter() @@ -683,10 +714,8 @@ pub(super) async fn api_entry_options( Query(q): Query, ) -> Result, EntryApiError> { let lang = request_ui_lang(&headers); - let user_id = current_user_id(&session).await.ok_or(( - StatusCode::UNAUTHORIZED, - Json(json!({ "error": tr(lang, "未登录", "尚未登入", "Not logged in") })), - ))?; + let user = require_valid_user_json(&state.pool, &session, lang).await?; + let user_id = user.id; let query = q.q.as_deref() @@ -738,10 +767,8 @@ pub(super) async fn api_entry_delete( Path(entry_id): Path, ) -> Result, EntryApiError> { let lang = request_ui_lang(&headers); - let user_id = current_user_id(&session).await.ok_or(( - StatusCode::UNAUTHORIZED, - Json(json!({ "error": tr(lang, "未登录", "尚未登入", "Not logged in") })), - ))?; + let user = require_valid_user_json(&state.pool, &session, lang).await?; + let user_id = user.id; delete_by_id(&state.pool, entry_id, user_id) .await @@ -760,10 +787,8 @@ pub(super) async fn api_trash_restore( Path(entry_id): Path, ) -> Result, EntryApiError> { let lang = request_ui_lang(&headers); - let user_id = current_user_id(&session).await.ok_or(( - StatusCode::UNAUTHORIZED, - Json(json!({ "error": tr(lang, "未登录", "尚未登入", "Not logged in") })), - ))?; + let user = require_valid_user_json(&state.pool, &session, lang).await?; + let user_id = user.id; restore_deleted_by_id(&state.pool, entry_id, user_id) .await @@ -782,10 +807,8 @@ pub(super) async fn api_trash_purge( Path(entry_id): Path, ) -> Result, EntryApiError> { let lang = request_ui_lang(&headers); - let user_id = current_user_id(&session).await.ok_or(( - StatusCode::UNAUTHORIZED, - Json(json!({ "error": tr(lang, "未登录", "尚未登入", "Not logged in") })), - ))?; + let user = require_valid_user_json(&state.pool, &session, lang).await?; + let user_id = user.id; purge_deleted_by_id(&state.pool, entry_id, user_id) .await @@ -818,10 +841,8 @@ pub(super) async fn api_secret_check_name( Query(params): Query, ) -> Result, EntryApiError> { let lang = request_ui_lang(&headers); - let user_id = current_user_id(&session).await.ok_or(( - StatusCode::UNAUTHORIZED, - Json(json!({ "error": tr(lang, "未登录", "尚未登入", "Not logged in") })), - ))?; + let user = require_valid_user_json(&state.pool, &session, lang).await?; + let user_id = user.id; let name = params.name.trim(); if name.is_empty() { @@ -914,10 +935,8 @@ pub(super) async fn api_secret_patch( } let lang = request_ui_lang(&headers); - let user_id = current_user_id(&session).await.ok_or(( - StatusCode::UNAUTHORIZED, - Json(json!({ "error": tr(lang, "未登录", "尚未登入", "Not logged in") })), - ))?; + let user = require_valid_user_json(&state.pool, &session, lang).await?; + let user_id = user.id; let name = body.name.as_ref().map(|s| s.trim()); let secret_type = body.secret_type.as_ref().map(|s| s.trim()); @@ -1123,10 +1142,8 @@ pub(super) async fn api_entry_secret_unlink( } let lang = request_ui_lang(&headers); - let user_id = current_user_id(&session).await.ok_or(( - StatusCode::UNAUTHORIZED, - Json(json!({ "error": tr(lang, "未登录", "尚未登入", "Not logged in") })), - ))?; + let user = require_valid_user_json(&state.pool, &session, lang).await?; + let user_id = user.id; let mut tx = state .pool @@ -1216,10 +1233,8 @@ pub(super) async fn api_entry_secrets_decrypt( Path(entry_id): Path, ) -> Result, EntryApiError> { let lang = request_ui_lang(&headers); - let user_id = current_user_id(&session).await.ok_or(( - StatusCode::UNAUTHORIZED, - Json(json!({ "error": tr(lang, "未登录", "尚未登入", "Not logged in") })), - ))?; + let user = require_valid_user_json(&state.pool, &session, lang).await?; + let user_id = user.id; let master_key = require_encryption_key(&headers, lang)?; diff --git a/crates/secrets-mcp/src/web/mod.rs b/crates/secrets-mcp/src/web/mod.rs index 3bd6eca..2e48227 100644 --- a/crates/secrets-mcp/src/web/mod.rs +++ b/crates/secrets-mcp/src/web/mod.rs @@ -1,10 +1,11 @@ use askama::Template; use axum::{ - Router, + Json, Router, http::{HeaderMap, StatusCode, header}, response::{Html, IntoResponse, Redirect, Response}, routing::{get, patch, post}, }; +use serde_json::json; use tower_sessions::Session; use uuid::Uuid; @@ -34,7 +35,7 @@ const AUDIT_PAGE_LIMIT: i64 = 10; // ── UI language ─────────────────────────────────────────────────────────────── #[derive(Clone, Copy)] -enum UiLang { +pub(super) enum UiLang { ZhCn, ZhTw, En, @@ -143,6 +144,71 @@ async fn require_valid_user( Ok(user) } +/// `Ok(None)` — unauthenticated or session invalidated (including `key_version` mismatch). +/// `Err(())` — database error loading the user. +pub(super) async fn load_session_user_strict( + pool: &sqlx::PgPool, + session: &Session, +) -> Result, ()> { + let Some(user_id) = current_user_id(session).await else { + return Ok(None); + }; + + let user = match secrets_core::service::user::get_user_by_id(pool, user_id).await { + Err(e) => { + tracing::error!(error = %e, %user_id, "load_session_user_strict: failed to load user"); + return Err(()); + } + Ok(None) => { + if let Err(e) = session.flush().await { + tracing::warn!(error = %e, "failed to flush stale session"); + } + return Ok(None); + } + Ok(Some(u)) => u, + }; + + let session_kv: Option = match session.get::(SESSION_KEY_VERSION).await { + Ok(v) => v, + Err(e) => { + tracing::warn!(error = %e, "failed to read key_version from session; treating as missing"); + None + } + }; + if let Some(kv) = session_kv + && kv != user.key_version + { + tracing::info!(%user_id, session_kv = kv, db_kv = user.key_version, "key_version mismatch; invalidating session (API)"); + if let Err(e) = session.flush().await { + tracing::warn!(error = %e, "failed to flush outdated session"); + } + return Ok(None); + } + + Ok(Some(user)) +} + +/// JSON API equivalent of [`require_valid_user`]: returns `401` with a JSON body instead of redirecting. +pub(super) async fn require_valid_user_json( + pool: &sqlx::PgPool, + session: &Session, + lang: UiLang, +) -> Result)> { + match load_session_user_strict(pool, session).await { + Ok(Some(user)) => Ok(user), + Ok(None) => Err(( + StatusCode::UNAUTHORIZED, + Json(json!({ "error": tr(lang, "未登录", "尚未登入", "Not logged in") })), + )), + Err(()) => Err(( + StatusCode::INTERNAL_SERVER_ERROR, + Json( + json!({ "error": tr(lang, "操作失败,请稍后重试", "操作失敗,請稍後重試", "Operation failed, please try again later") }), + ), + )), + } +} + fn request_user_agent(headers: &HeaderMap) -> Option { headers .get(header::USER_AGENT)