//! Server-rendered web UI: the route table plus the small helpers shared by
//! the handler submodules (`account`, `assets`, `audit`, `auth`, `entries`).
//!
//! Most helpers below are private and are not referenced in this file outside
//! the tests; since child modules can see their parent's private items, they
//! are presumably consumed by those handler submodules.
//!
//! NOTE(review): in this copy several generic argument lists appear to be
//! missing (`-> Option {`, `-> Result {`, `session.get::(…)`,
//! `fn render_template(tmpl: T)`, `-> Router {`). As written these do not
//! compile; restore the type parameters from version control.

use askama::Template;
use axum::{
    Router,
    http::{HeaderMap, StatusCode, header},
    response::{Html, IntoResponse, Redirect, Response},
    routing::{get, patch, post},
};
use tower_sessions::Session;
use uuid::Uuid;

use crate::AppState;
use crate::oauth::OAuthConfig;

mod account;
mod assets;
mod audit;
mod auth;
mod entries;

// ── Session keys ──────────────────────────────────────────────────────────────
// Keys under which values are stored in the `tower_sessions::Session`.
// Only `SESSION_USER_ID` and `SESSION_KEY_VERSION` are read in this file; the
// OAuth/login keys are presumably used by the `auth` handlers.

const SESSION_USER_ID: &str = "user_id";
const SESSION_OAUTH_STATE: &str = "oauth_state";
const SESSION_OAUTH_BIND_MODE: &str = "oauth_bind_mode";
const SESSION_LOGIN_PROVIDER: &str = "login_provider";
const SESSION_KEY_VERSION: &str = "key_version";

// ── Page limits ───────────────────────────────────────────────────────────────

/// Cap for HTML list (avoids loading unbounded rows into memory).
///
/// Presumably the page size of the entries list page; the `paginate` tests in
/// this file also use it as `page_size`.
const ENTRIES_PAGE_LIMIT: u32 = 50;
/// Page size for the audit log view (presumably `/audit` — confirm).
// NOTE(review): typed `i64` while `ENTRIES_PAGE_LIMIT` is `u32`; presumably
// each matches the API that consumes it (e.g. a SQL LIMIT bind vs. `paginate`)
// — confirm.
const AUDIT_PAGE_LIMIT: i64 = 10;

// ── UI language ───────────────────────────────────────────────────────────────

/// Languages the server-rendered UI can be displayed in.
#[derive(Clone, Copy)]
enum UiLang {
    /// Simplified Chinese; also the fallback chosen by `request_ui_lang`.
    ZhCn,
    /// Traditional Chinese.
    ZhTw,
    /// English.
    En,
}

/// Chooses the UI language from the request's `Accept-Language` header.
///
/// Matching is a case-insensitive substring search over the whole header
/// value, checked in this fixed priority order:
/// 1. `zh-tw`, `zh-hk` or `zh-hant` anywhere → [`UiLang::ZhTw`]
/// 2. any other `zh` → [`UiLang::ZhCn`]
/// 3. `en` → [`UiLang::En`]
///
/// Falls back to [`UiLang::ZhCn`] when the header is absent, cannot be read as
/// a string (`HeaderValue::to_str` fails), or matches none of the above.
///
/// NOTE(review): quality values (`;q=`) and the order of listed languages are
/// ignored, so e.g. `en-US, zh;q=0.1` still yields `ZhCn`; confirm this
/// Chinese-first policy is intended.
fn request_ui_lang(headers: &HeaderMap) -> UiLang {
    let Some(raw) = headers
        .get(header::ACCEPT_LANGUAGE)
        .and_then(|v| v.to_str().ok())
    else {
        return UiLang::ZhCn;
    };
    let lower = raw.to_ascii_lowercase();
    // Traditional variants must be tested before the generic `zh` match,
    // otherwise they would be classified as Simplified Chinese.
    if lower.contains("zh-tw") || lower.contains("zh-hk") || lower.contains("zh-hant") {
        UiLang::ZhTw
    } else if lower.contains("zh") {
        UiLang::ZhCn
    } else if lower.contains("en") {
        UiLang::En
    } else {
        UiLang::ZhCn
    }
}

/// Returns whichever of the three translations matches `lang`.
///
/// All inputs are `&'static str`, so callers typically pass string literals.
fn tr(lang: UiLang, zh_cn: &'static str, zh_tw: &'static str, en: &'static str) -> &'static str {
    match lang {
        UiLang::ZhCn => zh_cn,
        UiLang::ZhTw => zh_tw,
        UiLang::En => en,
    }
}

// ── App state helpers ─────────────────────────────────────────────────────────

/// Returns the Google OAuth configuration, or `None` when
/// `state.google_config` is unset (presumably meaning Google login is not
/// configured).
fn google_cfg(state: &AppState) -> Option<&OAuthConfig> {
    state.google_config.as_ref()
}

/// Reads the logged-in user's id from the session (`SESSION_USER_ID`).
///
/// Returns `None` when the key is absent, when reading the session store
/// fails, or when the stored string is not a valid UUID; the latter two cases
/// log a warning. The session itself is never modified here.
// NOTE(review): return type is presumably `Option<Uuid>` and the turbofish
// presumably `get::<String>` — the generic arguments are missing in this copy.
async fn current_user_id(session: &Session) -> Option {
    match session.get::(SESSION_USER_ID).await {
        Ok(opt) => match opt {
            Some(s) => match Uuid::parse_str(&s) {
                Ok(id) => Some(id),
                Err(e) => {
                    tracing::warn!(error = %e, user_id_str = %s, "invalid user_id UUID in session");
                    None
                }
            },
            None => None,
        },
        Err(e) => {
            tracing::warn!(error = %e, "failed to read user_id from session");
            None
        }
    }
}

/// Load and validate the current user from session and DB.
///
/// Returns `Ok(user)` when the session holds a user id that still exists in
/// the database and whose session `key_version` (if present) matches the DB
/// value.
///
/// Otherwise returns `Err` carrying a ready-made response:
/// - no readable/valid `user_id` in the session → redirect to `/login`
///   (the session is *not* flushed);
/// - the DB lookup fails → `500 Internal Server Error` (the error is logged
///   together with `context`; the session is not flushed);
/// - the user no longer exists in the database → session flushed, redirect
///   to `/login`;
/// - the stored `key_version` differs from the DB value (passphrase changed
///   on another device since this session was created) → session flushed,
///   redirect to `/login`.
///
/// A failed flush is only logged; the redirect is returned regardless.
///
/// `context` appears only as a field in the DB-error log line (presumably to
/// identify the calling handler).
///
/// NOTE(review): a session with no `key_version`, or one whose `key_version`
/// fails to read, skips the version check and is accepted; confirm this is
/// intended (e.g. for sessions created before the key was stored).
// NOTE(review): presumably `-> Result<User, Response>` — every `Err` value is
// a `Response`; the user model type is not visible in this file.
async fn require_valid_user(
    pool: &sqlx::PgPool,
    session: &Session,
    context: &str,
) -> Result {
    let Some(user_id) = current_user_id(session).await else {
        return Err(Redirect::to("/login").into_response());
    };
    let user = match secrets_core::service::user::get_user_by_id(pool, user_id).await {
        Err(e) => {
            tracing::error!(error = %e, %user_id, context, "failed to load user");
            return Err(StatusCode::INTERNAL_SERVER_ERROR.into_response());
        }
        Ok(None) => {
            // The session points at a user that no longer exists: drop it.
            if let Err(e) = session.flush().await {
                tracing::warn!(error = %e, "failed to flush stale session");
            }
            return Err(Redirect::to("/login").into_response());
        }
        Ok(Some(u)) => u,
    };
    // NOTE(review): presumably `Option<i32>`/`get::<i32>` (same type as
    // `user.key_version`) — generic arguments are missing in this copy.
    let session_kv: Option = match session.get::(SESSION_KEY_VERSION).await {
        Ok(v) => v,
        Err(e) => {
            tracing::warn!(error = %e, "failed to read key_version from session; treating as missing");
            None
        }
    };
    // The version check is only enforced when the session carries a key_version.
    if let Some(kv) = session_kv && kv != user.key_version {
        tracing::info!(%user_id, session_kv = kv, db_kv = user.key_version, "key_version mismatch; invalidating session");
        if let Err(e) = session.flush().await {
            tracing::warn!(error = %e, "failed to flush outdated session");
        }
        return Err(Redirect::to("/login").into_response());
    }
    Ok(user)
}

/// Returns the request's `User-Agent` header, trimmed and owned.
///
/// Returns `None` when the header is absent, cannot be read as a string
/// (`HeaderValue::to_str` fails), or is empty after trimming.
// NOTE(review): presumably `-> Option<String>` (generic argument missing here).
fn request_user_agent(headers: &HeaderMap) -> Option {
    headers
        .get(header::USER_AGENT)
        .and_then(|value| value.to_str().ok())
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(ToOwned::to_owned)
}

/// Clamps a 1-based page request against a row count and a page size.
///
/// Returns `(current_page, total_pages, offset)`:
/// - a `page_size` of 0 is treated as 1;
/// - a negative `total_count` is treated as 0, and counts above `u32::MAX`
///   saturate to `u32::MAX`;
/// - `total_pages` is at least 1, even when there are no rows;
/// - `current_page` is `page` clamped into `1..=total_pages`, so both page 0
///   and past-the-end pages are accepted;
/// - `offset` is the zero-based row offset of `current_page`.
fn paginate(page: u32, total_count: i64, page_size: u32) -> (u32, u32, u32) {
    let page_size = page_size.max(1);
    let safe_total_count = u32::try_from(total_count.max(0)).unwrap_or(u32::MAX);
    let total_pages = safe_total_count.div_ceil(page_size).max(1);
    let current_page = page.max(1).min(total_pages);
    // `current_page >= 1`, so the subtraction cannot underflow. Because
    // `current_page <= total_pages`, the product stays below the clamped
    // total, so `saturating_mul` is purely defensive.
    let offset = (current_page - 1).saturating_mul(page_size);
    (current_page, total_pages, offset)
}

/// Renders an askama template into an HTML response.
///
/// On render failure the error is logged and mapped to
/// `StatusCode::INTERNAL_SERVER_ERROR`, which `?` propagates as the `Err`.
// NOTE(review): presumably
// `fn render_template<T: Template>(tmpl: T) -> Result<Response, StatusCode>`
// — the generic parameter list and `Result` arguments are missing here.
fn render_template(tmpl: T) -> Result {
    let html = tmpl.render().map_err(|e| {
        tracing::error!(error = %e, "template render error");
        StatusCode::INTERNAL_SERVER_ERROR
    })?;
    Ok(Html(html).into_response())
}

// ── Routes ────────────────────────────────────────────────────────────────────

/// Builds the router for the web front end.
///
/// Registered in order: crawler/AI text files and static assets, the OAuth
/// protected-resource metadata document, landing/login and the Google OAuth
/// flow, the HTML pages, account provider binding, and the `/api/*`
/// endpoints.
// NOTE(review): presumably `-> Router<AppState>` (type argument missing here).
pub fn web_router() -> Router {
    Router::new()
        // Plain-text files, client-side i18n script, and icons.
        .route("/robots.txt", get(assets::robots_txt))
        .route("/llms.txt", get(assets::llms_txt))
        .route("/ai.txt", get(assets::ai_txt))
        .route("/static/i18n.js", get(assets::i18n_js))
        .route("/favicon.svg", get(assets::favicon_svg))
        // `/favicon.ico` permanently redirects to the SVG icon.
        .route(
            "/favicon.ico",
            get(|| async { Redirect::permanent("/favicon.svg") }),
        )
        .route(
            "/.well-known/oauth-protected-resource",
            get(assets::oauth_protected_resource_metadata),
        )
        // Landing, login, and Google OAuth flow; logout is POST-only.
        .route("/", get(auth::home_page))
        .route("/login", get(auth::login_page))
        .route("/auth/google", get(auth::auth_google))
        .route("/auth/google/callback", get(auth::auth_google_callback))
        .route("/auth/logout", post(auth::auth_logout))
        // HTML pages.
        .route("/dashboard", get(account::dashboard))
        .route("/entries", get(entries::entries_page))
        .route("/audit", get(audit::audit_page))
        // Linking / unlinking external login providers.
        .route("/account/bind/google", get(auth::account_bind_google))
        .route("/account/unbind/{provider}", post(auth::account_unbind))
        // Key setup/change and API-key management.
        .route("/api/key-salt", get(account::api_key_salt))
        .route("/api/key-setup", post(account::api_key_setup))
        .route("/api/key-change", post(account::api_key_change))
        .route("/api/apikey", get(account::api_apikey_get))
        .route(
            "/api/apikey/regenerate",
            post(account::api_apikey_regenerate),
        )
        // Entry and secret endpoints.
        .route(
            "/api/entries/{id}",
            patch(entries::api_entry_patch).delete(entries::api_entry_delete),
        )
        .route(
            "/api/entries/{entry_id}/secrets/{secret_id}",
            axum::routing::delete(entries::api_entry_secret_unlink),
        )
        .route(
            "/api/entries/{id}/secrets/decrypt",
            get(entries::api_entry_secrets_decrypt),
        )
        .route("/api/secrets/{secret_id}", patch(entries::api_secret_patch))
        .route(
            "/api/secrets/check-name",
            get(entries::api_secret_check_name),
        )
}

#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use super::*;

    // NOTE(review): the two `client_ip_*` tests exercise
    // `crate::client_ip::extract_client_ip_parts`, not code in this module;
    // they may belong with that module's own tests.

    // An untrusted `x-forwarded-for` header must not override the socket peer.
    // NOTE(review): unlike the next test, this one does not skip when
    // `TRUST_PROXY` is set; if that variable enables proxy trust, this test may
    // fail in such environments — confirm.
    #[test]
    fn client_ip_ignores_forwarded_headers_without_trusted_proxy() {
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "203.0.113.10".parse().unwrap());
        let ip = crate::client_ip::extract_client_ip_parts(
            &headers,
            SocketAddr::from(([127, 0, 0, 1], 9315)),
        );
        assert_eq!(ip, "127.0.0.1");
    }

    // NOTE(review): despite its name, this test asserts that the forwarded
    // header is *ignored* (with `TRUST_PROXY` unset); it does not cover the
    // trusted-proxy path. Consider renaming it.
    #[test]
    fn client_ip_uses_valid_forwarded_header_with_trusted_proxy() {
        // This test relies on TRUST_PROXY being unset (default); skip if set in env
        if std::env::var("TRUST_PROXY").is_ok() {
            return;
        }
        let mut headers = HeaderMap::new();
        headers.insert("x-forwarded-for", "203.0.113.10, 10.0.0.1".parse().unwrap());
        // Direct connection IP is used when TRUST_PROXY is not set
        let ip = crate::client_ip::extract_client_ip_parts(
            &headers,
            SocketAddr::from(([127, 0, 0, 1], 9315)),
        );
        assert_eq!(ip, "127.0.0.1");
    }

    #[test]
    fn request_ui_lang_prefers_zh_cn_over_en_fallback() {
        let mut headers = HeaderMap::new();
        headers.insert(header::ACCEPT_LANGUAGE, "zh-CN, en;q=0.5".parse().unwrap());
        assert!(matches!(request_ui_lang(&headers), UiLang::ZhCn));
    }

    #[test]
    fn request_ui_lang_detects_traditional_chinese_variants() {
        let mut headers = HeaderMap::new();
        headers.insert(
            header::ACCEPT_LANGUAGE,
            "zh-Hant, en;q=0.5".parse().unwrap(),
        );
        assert!(matches!(request_ui_lang(&headers), UiLang::ZhTw));
    }

    // 12 rows at 10 per page = 2 pages; page 100 clamps to 2, offset 10.
    #[test]
    fn paginate_clamps_page_before_computing_offset() {
        let (current_page, total_pages, offset) = paginate(100, 12, 10);
        assert_eq!(current_page, 2);
        assert_eq!(total_pages, 2);
        assert_eq!(offset, 10);
    }

    #[test]
    fn paginate_handles_large_page_without_overflow() {
        let (current_page, total_pages, offset) = paginate(u32::MAX, 1, ENTRIES_PAGE_LIMIT);
        assert_eq!(current_page, 1);
        assert_eq!(total_pages, 1);
        assert_eq!(offset, 0);
    }

    // Row counts beyond u32::MAX saturate instead of wrapping or panicking.
    #[test]
    fn paginate_saturates_large_total_count() {
        let (_, total_pages, _) = paginate(1, i64::MAX, ENTRIES_PAGE_LIMIT);
        assert_eq!(total_pages, u32::MAX.div_ceil(ENTRIES_PAGE_LIMIT));
    }
}