use std::net::SocketAddr; use askama::Template; use axum::{ extract::{ConnectInfo, Path, Query, State}, http::{HeaderMap, StatusCode}, response::{IntoResponse, Redirect, Response}, }; use serde::Deserialize; use tower_sessions::Session; use secrets_core::audit::log_login; use secrets_core::service::user::{ OAuthProfile, bind_oauth_account, find_or_create_user, unbind_oauth_account, }; use crate::AppState; use crate::oauth::{OAuthConfig, OAuthUserInfo, google_auth_url, random_state}; use super::{ SESSION_KEY_VERSION, SESSION_LOGIN_PROVIDER, SESSION_OAUTH_BIND_MODE, SESSION_OAUTH_STATE, SESSION_USER_ID, current_user_id, google_cfg, render_template, request_user_agent, }; #[derive(Template)] #[template(path = "login.html")] struct LoginTemplate { has_google: bool, base_url: String, version: &'static str, } #[derive(Template)] #[template(path = "home.html")] struct HomeTemplate { is_logged_in: bool, base_url: String, version: &'static str, } // ── Home page (public) ─────────────────────────────────────────────────────── pub(super) async fn home_page( State(state): State, session: Session, ) -> Result { let is_logged_in = current_user_id(&session).await.is_some(); let tmpl = HomeTemplate { is_logged_in, base_url: state.base_url.clone(), version: env!("CARGO_PKG_VERSION"), }; render_template(tmpl) } // ── Login page ──────────────────────────────────────────────────────────────── pub(super) async fn login_page( State(state): State, session: Session, ) -> Result { if let Some(_uid) = current_user_id(&session).await { return Ok(Redirect::to("/dashboard").into_response()); } let tmpl = LoginTemplate { has_google: state.google_config.is_some(), base_url: state.base_url.clone(), version: env!("CARGO_PKG_VERSION"), }; render_template(tmpl) } // ── Google OAuth ────────────────────────────────────────────────────────────── pub(super) async fn auth_google( State(state): State, session: Session, ) -> Result { let config = google_cfg(&state).ok_or(StatusCode::SERVICE_UNAVAILABLE)?; let oauth_state = random_state(); session .insert(SESSION_OAUTH_STATE, &oauth_state) .await .map_err(|e| { tracing::error!(error = %e, "failed to insert oauth_state into session"); StatusCode::INTERNAL_SERVER_ERROR })?; let url = google_auth_url(config, &oauth_state); Ok(Redirect::to(&url).into_response()) } #[derive(Deserialize)] pub(super) struct OAuthCallbackQuery { code: Option, state: Option, error: Option, } pub(super) async fn auth_google_callback( State(state): State, connect_info: ConnectInfo, headers: HeaderMap, session: Session, Query(params): Query, ) -> Result { let client_ip = Some(crate::client_ip::extract_client_ip_parts( &headers, connect_info.0, )); let user_agent = request_user_agent(&headers); handle_oauth_callback( &state, &session, params, "google", client_ip.as_deref(), user_agent.as_deref(), |s, cfg, code| { Box::pin(crate::oauth::google::exchange_code( &s.http_client, cfg, code, )) }, ) .await } // ── Shared OAuth callback handler ───────────────────────────────────────────── async fn handle_oauth_callback( state: &AppState, session: &Session, params: OAuthCallbackQuery, provider: &str, client_ip: Option<&str>, user_agent: Option<&str>, exchange_fn: F, ) -> Result where F: for<'a> Fn( &'a AppState, &'a OAuthConfig, &'a str, ) -> std::pin::Pin< Box> + Send + 'a>, >, { if let Some(err) = params.error { tracing::warn!(provider, error = %err, "OAuth error"); return Ok(Redirect::to("/login?error=oauth_error").into_response()); } let Some(code) = params.code else { tracing::warn!(provider, "OAuth callback missing code"); return Ok(Redirect::to("/login?error=oauth_missing_code").into_response()); }; let Some(returned_state) = params.state.as_deref() else { tracing::warn!(provider, "OAuth callback missing state"); return Ok(Redirect::to("/login?error=oauth_missing_state").into_response()); }; let expected_state: Option = session.get(SESSION_OAUTH_STATE).await.map_err(|e| { tracing::error!(provider, error = %e, "failed to read oauth_state from session"); StatusCode::INTERNAL_SERVER_ERROR })?; if expected_state.as_deref() != Some(returned_state) { tracing::warn!( provider, expected_present = expected_state.is_some(), "OAuth state mismatch (empty session often means SameSite=Strict or server restart)" ); return Ok(Redirect::to("/login?error=oauth_state").into_response()); } if let Err(e) = session.remove::(SESSION_OAUTH_STATE).await { tracing::warn!(provider, error = %e, "failed to remove oauth_state from session"); } let config = match provider { "google" => state .google_config .as_ref() .ok_or(StatusCode::SERVICE_UNAVAILABLE)?, _ => return Err(StatusCode::BAD_REQUEST), }; let user_info = exchange_fn(state, config, code.as_str()) .await .map_err(|e| { tracing::error!(provider, error = %e, "failed to exchange OAuth code"); StatusCode::INTERNAL_SERVER_ERROR })?; let bind_mode: bool = match session.get::(SESSION_OAUTH_BIND_MODE).await { Ok(v) => v.unwrap_or(false), Err(e) => { tracing::error!( provider, error = %e, "failed to read oauth_bind_mode from session" ); return Err(StatusCode::INTERNAL_SERVER_ERROR); } }; if bind_mode { let user_id = current_user_id(session) .await .ok_or(StatusCode::UNAUTHORIZED)?; if let Err(e) = session.remove::(SESSION_OAUTH_BIND_MODE).await { tracing::warn!(provider, error = %e, "failed to remove oauth_bind_mode from session after bind"); } let profile = OAuthProfile { provider: user_info.provider, provider_id: user_info.provider_id, email: user_info.email, name: user_info.name, avatar_url: user_info.avatar_url, }; bind_oauth_account(&state.pool, user_id, profile) .await .map_err(|e| { tracing::error!(error = %e, "failed to bind OAuth account"); StatusCode::INTERNAL_SERVER_ERROR })?; return Ok(Redirect::to("/dashboard?bound=1").into_response()); } let profile = OAuthProfile { provider: user_info.provider, provider_id: user_info.provider_id, email: user_info.email, name: user_info.name, avatar_url: user_info.avatar_url, }; let (user, _is_new) = find_or_create_user(&state.pool, profile) .await .map_err(|e| { tracing::error!(error = %e, "failed to find or create user"); StatusCode::INTERNAL_SERVER_ERROR })?; session .insert(SESSION_USER_ID, user.id.to_string()) .await .map_err(|e| { tracing::error!( error = %e, user_id = %user.id, "failed to insert user_id into session after OAuth" ); StatusCode::INTERNAL_SERVER_ERROR })?; session .insert(SESSION_LOGIN_PROVIDER, &provider) .await .map_err(|e| { tracing::error!( provider, error = %e, "failed to insert login_provider into session after OAuth" ); StatusCode::INTERNAL_SERVER_ERROR })?; if let Err(e) = session.insert(SESSION_KEY_VERSION, user.key_version).await { tracing::warn!(error = %e, user_id = %user.id, "failed to insert key_version into session after OAuth"); } log_login( &state.pool, "oauth", provider, user.id, client_ip, user_agent, ) .await; Ok(Redirect::to("/dashboard").into_response()) } // ── Logout ──────────────────────────────────────────────────────────────────── pub(super) async fn auth_logout(session: Session) -> impl IntoResponse { if let Err(e) = session.flush().await { tracing::warn!(error = %e, "failed to flush session on logout"); } Redirect::to("/") } // ── Account bind/unbind ─────────────────────────────────────────────────────── pub(super) async fn account_bind_google( State(state): State, session: Session, ) -> Result { let _ = current_user_id(&session) .await .ok_or(StatusCode::UNAUTHORIZED)?; session .insert(SESSION_OAUTH_BIND_MODE, true) .await .map_err(|e| { tracing::error!(error = %e, "failed to insert oauth_bind_mode into session"); StatusCode::INTERNAL_SERVER_ERROR })?; let config = google_cfg(&state).ok_or(StatusCode::SERVICE_UNAVAILABLE)?; let oauth_state = random_state(); if let Err(e) = session.insert(SESSION_OAUTH_STATE, &oauth_state).await { tracing::error!(error = %e, "failed to insert oauth_state for account bind flow"); if let Err(rm) = session.remove::(SESSION_OAUTH_BIND_MODE).await { tracing::warn!(error = %rm, "failed to roll back oauth_bind_mode after oauth_state insert failure"); } return Err(StatusCode::INTERNAL_SERVER_ERROR); } let url = google_auth_url(config, &oauth_state); Ok(Redirect::to(&url).into_response()) } pub(super) async fn account_unbind( State(state): State, Path(provider): Path, session: Session, ) -> Result { let user_id = current_user_id(&session) .await .ok_or(StatusCode::UNAUTHORIZED)?; let current_login_provider = session .get::(SESSION_LOGIN_PROVIDER) .await .map_err(|e| { tracing::error!(error = %e, "failed to read login_provider from session"); StatusCode::INTERNAL_SERVER_ERROR })?; unbind_oauth_account( &state.pool, user_id, &provider, current_login_provider.as_deref(), ) .await .map_err(|e| { tracing::warn!(error = %e, "failed to unbind oauth account"); StatusCode::BAD_REQUEST })?; Ok(Redirect::to("/dashboard?unbound=1").into_response()) }