use anyhow::Result; use serde_json::Value; use sqlx::PgPool; use uuid::Uuid; use crate::models::{OauthAccount, User}; pub struct OAuthProfile { pub provider: String, pub provider_id: String, pub email: Option, pub name: Option, pub avatar_url: Option, } /// Find or create a user from an OAuth profile. /// Returns (user, is_new) where is_new indicates first-time registration. pub async fn find_or_create_user(pool: &PgPool, profile: OAuthProfile) -> Result<(User, bool)> { // Use a transaction with FOR UPDATE to prevent TOCTOU race conditions let mut tx = pool.begin().await?; // Check if this OAuth account already exists (with row lock) let existing: Option = sqlx::query_as( "SELECT id, user_id, provider, provider_id, email, name, avatar_url, created_at \ FROM oauth_accounts WHERE provider = $1 AND provider_id = $2 FOR UPDATE", ) .bind(&profile.provider) .bind(&profile.provider_id) .fetch_optional(&mut *tx) .await?; if let Some(oa) = existing { let user: User = sqlx::query_as( "SELECT id, email, name, avatar_url, key_salt, key_check, key_params, api_key, key_version, created_at, updated_at \ FROM users WHERE id = $1", ) .bind(oa.user_id) .fetch_one(&mut *tx) .await?; tx.commit().await?; return Ok((user, false)); } // New user — create records (no key yet; user sets passphrase on dashboard) let display_name = profile .name .clone() .unwrap_or_else(|| profile.email.clone().unwrap_or_else(|| "User".to_string())); let user: User = sqlx::query_as( "INSERT INTO users (email, name, avatar_url) \ VALUES ($1, $2, $3) \ RETURNING id, email, name, avatar_url, key_salt, key_check, key_params, api_key, key_version, created_at, updated_at", ) .bind(&profile.email) .bind(&display_name) .bind(&profile.avatar_url) .fetch_one(&mut *tx) .await?; sqlx::query( "INSERT INTO oauth_accounts (user_id, provider, provider_id, email, name, avatar_url) \ VALUES ($1, $2, $3, $4, $5, $6)", ) .bind(user.id) .bind(&profile.provider) .bind(&profile.provider_id) .bind(&profile.email) .bind(&profile.name) .bind(&profile.avatar_url) .execute(&mut *tx) .await?; tx.commit().await?; Ok((user, true)) } /// Re-encrypt all of a user's secrets from `old_key` to `new_key` and update the key metadata. /// /// Runs entirely inside a single database transaction: if any secret fails to re-encrypt /// the whole operation is rolled back, leaving the database unchanged. pub async fn change_user_key( pool: &PgPool, user_id: Uuid, old_key: &[u8; 32], new_key: &[u8; 32], new_salt: &[u8], new_key_check: &[u8], new_key_params: &Value, ) -> Result<()> { let mut tx = pool.begin().await?; let secrets: Vec<(uuid::Uuid, Vec)> = sqlx::query_as("SELECT id, encrypted FROM secrets WHERE user_id = $1 FOR UPDATE") .bind(user_id) .fetch_all(&mut *tx) .await?; for (id, encrypted) in &secrets { let plaintext = crate::crypto::decrypt(old_key, encrypted)?; let new_encrypted = crate::crypto::encrypt(new_key, &plaintext)?; sqlx::query("UPDATE secrets SET encrypted = $1, updated_at = NOW() WHERE id = $2") .bind(&new_encrypted) .bind(id) .execute(&mut *tx) .await?; } sqlx::query( "UPDATE users SET key_salt = $1, key_check = $2, key_params = $3, \ key_version = key_version + 1, updated_at = NOW() \ WHERE id = $4", ) .bind(new_salt) .bind(new_key_check) .bind(new_key_params) .bind(user_id) .execute(&mut *tx) .await?; tx.commit().await?; Ok(()) } /// Store the PBKDF2 salt, key_check, and params for a user's passphrase setup. pub async fn update_user_key_setup( pool: &PgPool, user_id: Uuid, key_salt: &[u8], key_check: &[u8], key_params: &Value, ) -> Result<()> { sqlx::query( "UPDATE users SET key_salt = $1, key_check = $2, key_params = $3, updated_at = NOW() \ WHERE id = $4", ) .bind(key_salt) .bind(key_check) .bind(key_params) .bind(user_id) .execute(pool) .await?; Ok(()) } /// Fetch a user by ID. pub async fn get_user_by_id(pool: &PgPool, user_id: Uuid) -> Result> { let user = sqlx::query_as( "SELECT id, email, name, avatar_url, key_salt, key_check, key_params, api_key, key_version, created_at, updated_at \ FROM users WHERE id = $1", ) .bind(user_id) .fetch_optional(pool) .await?; Ok(user) } /// List all OAuth accounts linked to a user. pub async fn list_oauth_accounts(pool: &PgPool, user_id: Uuid) -> Result> { let accounts = sqlx::query_as( "SELECT id, user_id, provider, provider_id, email, name, avatar_url, created_at \ FROM oauth_accounts WHERE user_id = $1 ORDER BY created_at", ) .bind(user_id) .fetch_all(pool) .await?; Ok(accounts) } /// Bind an additional OAuth account to an existing user. pub async fn bind_oauth_account( pool: &PgPool, user_id: Uuid, profile: OAuthProfile, ) -> Result { // Use a transaction with FOR UPDATE to prevent TOCTOU race conditions let mut tx = pool.begin().await?; // Check if this provider_id is already linked to someone else (with row lock) let conflict: Option<(Uuid,)> = sqlx::query_as( "SELECT user_id FROM oauth_accounts WHERE provider = $1 AND provider_id = $2 FOR UPDATE", ) .bind(&profile.provider) .bind(&profile.provider_id) .fetch_optional(&mut *tx) .await?; if let Some((existing_user_id,)) = conflict { if existing_user_id != user_id { anyhow::bail!( "This {} account is already linked to a different user", profile.provider ); } anyhow::bail!( "This {} account is already linked to your account", profile.provider ); } let existing_provider_for_user: Option<(String,)> = sqlx::query_as( "SELECT provider_id FROM oauth_accounts WHERE user_id = $1 AND provider = $2 FOR UPDATE", ) .bind(user_id) .bind(&profile.provider) .fetch_optional(&mut *tx) .await?; if existing_provider_for_user.is_some() { anyhow::bail!( "You already linked a {} account. Unlink the other provider instead of binding multiple {} accounts.", profile.provider, profile.provider ); } let account: OauthAccount = sqlx::query_as( "INSERT INTO oauth_accounts (user_id, provider, provider_id, email, name, avatar_url) \ VALUES ($1, $2, $3, $4, $5, $6) \ RETURNING id, user_id, provider, provider_id, email, name, avatar_url, created_at", ) .bind(user_id) .bind(&profile.provider) .bind(&profile.provider_id) .bind(&profile.email) .bind(&profile.name) .bind(&profile.avatar_url) .fetch_one(&mut *tx) .await?; tx.commit().await?; Ok(account) } /// Unbind an OAuth account. Ensures at least one remains and blocks unlinking the current login provider. pub async fn unbind_oauth_account( pool: &PgPool, user_id: Uuid, provider: &str, current_login_provider: Option<&str>, ) -> Result<()> { if current_login_provider == Some(provider) { anyhow::bail!( "Cannot unlink the {} account you are currently using to sign in", provider ); } let mut tx = pool.begin().await?; let locked_accounts: Vec<(String,)> = sqlx::query_as("SELECT provider FROM oauth_accounts WHERE user_id = $1 FOR UPDATE") .bind(user_id) .fetch_all(&mut *tx) .await?; let count = locked_accounts.len(); if count <= 1 { anyhow::bail!("Cannot unbind the last OAuth account. Please link another account first."); } sqlx::query("DELETE FROM oauth_accounts WHERE user_id = $1 AND provider = $2") .bind(user_id) .bind(provider) .execute(&mut *tx) .await?; tx.commit().await?; Ok(()) } #[cfg(test)] mod tests { use super::*; async fn maybe_test_pool() -> Option { let database_url = match std::env::var("SECRETS_DATABASE_URL") { Ok(v) => v, Err(_) => { eprintln!("skip user service tests: SECRETS_DATABASE_URL not set"); return None; } }; let pool = match sqlx::PgPool::connect(&database_url).await { Ok(pool) => pool, Err(e) => { eprintln!("skip user service tests: cannot connect to database: {e}"); return None; } }; if let Err(e) = crate::db::migrate(&pool).await { eprintln!("skip user service tests: migrate failed: {e}"); return None; } Some(pool) } async fn cleanup_user_rows(pool: &PgPool, user_id: Uuid) -> Result<()> { sqlx::query("DELETE FROM oauth_accounts WHERE user_id = $1") .bind(user_id) .execute(pool) .await?; sqlx::query("DELETE FROM users WHERE id = $1") .bind(user_id) .execute(pool) .await?; Ok(()) } #[tokio::test] async fn unbind_oauth_account_removes_only_requested_provider() -> Result<()> { let Some(pool) = maybe_test_pool().await else { return Ok(()); }; let user_id = Uuid::from_u128(rand::random()); cleanup_user_rows(&pool, user_id).await?; sqlx::query("INSERT INTO users (id, name) VALUES ($1, '')") .bind(user_id) .execute(&pool) .await?; sqlx::query( "INSERT INTO oauth_accounts (user_id, provider, provider_id, email, name, avatar_url) \ VALUES ($1, 'google', $2, NULL, NULL, NULL), \ ($1, 'github', $3, NULL, NULL, NULL)", ) .bind(user_id) .bind(format!("google-{user_id}")) .bind(format!("github-{user_id}")) .execute(&pool) .await?; unbind_oauth_account(&pool, user_id, "github", Some("google")).await?; let remaining: Vec<(String,)> = sqlx::query_as( "SELECT provider FROM oauth_accounts WHERE user_id = $1 ORDER BY provider", ) .bind(user_id) .fetch_all(&pool) .await?; assert_eq!(remaining, vec![("google".to_string(),)]); cleanup_user_rows(&pool, user_id).await?; Ok(()) } }