350 lines
11 KiB
Rust
350 lines
11 KiB
Rust
use anyhow::Result;
|
|
use serde_json::Value;
|
|
use sqlx::PgPool;
|
|
use uuid::Uuid;
|
|
|
|
use crate::models::{OauthAccount, User};
|
|
|
|
/// Identity information returned by an OAuth provider after a successful sign-in.
///
/// `provider` together with `provider_id` identifies the external account; the remaining fields
/// are optional profile data copied into `users` / `oauth_accounts` rows by this module.
#[derive(Debug, Clone)]
pub struct OAuthProfile {
    /// Provider name as stored in `oauth_accounts.provider` (e.g. "google", "github").
    pub provider: String,
    /// The account's identifier at the provider, stored in `oauth_accounts.provider_id`.
    pub provider_id: String,
    /// Email reported by the provider, if any.
    pub email: Option<String>,
    /// Display name reported by the provider, if any.
    pub name: Option<String>,
    /// Avatar image URL reported by the provider, if any.
    pub avatar_url: Option<String>,
}
|
|
|
|
/// Find or create a user from an OAuth profile.
|
|
/// Returns (user, is_new) where is_new indicates first-time registration.
|
|
pub async fn find_or_create_user(pool: &PgPool, profile: OAuthProfile) -> Result<(User, bool)> {
|
|
// Use a transaction with FOR UPDATE to prevent TOCTOU race conditions
|
|
let mut tx = pool.begin().await?;
|
|
|
|
// Check if this OAuth account already exists (with row lock)
|
|
let existing: Option<OauthAccount> = sqlx::query_as(
|
|
"SELECT id, user_id, provider, provider_id, email, name, avatar_url, created_at \
|
|
FROM oauth_accounts WHERE provider = $1 AND provider_id = $2 FOR UPDATE",
|
|
)
|
|
.bind(&profile.provider)
|
|
.bind(&profile.provider_id)
|
|
.fetch_optional(&mut *tx)
|
|
.await?;
|
|
|
|
if let Some(oa) = existing {
|
|
let user: User = sqlx::query_as(
|
|
"SELECT id, email, name, avatar_url, key_salt, key_check, key_params, api_key, key_version, created_at, updated_at \
|
|
FROM users WHERE id = $1",
|
|
)
|
|
.bind(oa.user_id)
|
|
.fetch_one(&mut *tx)
|
|
.await?;
|
|
tx.commit().await?;
|
|
return Ok((user, false));
|
|
}
|
|
|
|
// New user — create records (no key yet; user sets passphrase on dashboard)
|
|
let display_name = profile
|
|
.name
|
|
.clone()
|
|
.unwrap_or_else(|| profile.email.clone().unwrap_or_else(|| "User".to_string()));
|
|
|
|
let user: User = sqlx::query_as(
|
|
"INSERT INTO users (email, name, avatar_url) \
|
|
VALUES ($1, $2, $3) \
|
|
RETURNING id, email, name, avatar_url, key_salt, key_check, key_params, api_key, key_version, created_at, updated_at",
|
|
)
|
|
.bind(&profile.email)
|
|
.bind(&display_name)
|
|
.bind(&profile.avatar_url)
|
|
.fetch_one(&mut *tx)
|
|
.await?;
|
|
|
|
sqlx::query(
|
|
"INSERT INTO oauth_accounts (user_id, provider, provider_id, email, name, avatar_url) \
|
|
VALUES ($1, $2, $3, $4, $5, $6)",
|
|
)
|
|
.bind(user.id)
|
|
.bind(&profile.provider)
|
|
.bind(&profile.provider_id)
|
|
.bind(&profile.email)
|
|
.bind(&profile.name)
|
|
.bind(&profile.avatar_url)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
|
|
tx.commit().await?;
|
|
|
|
Ok((user, true))
|
|
}
|
|
|
|
/// Re-encrypt all of a user's secrets from `old_key` to `new_key` and update the key metadata.
|
|
///
|
|
/// Runs entirely inside a single database transaction: if any secret fails to re-encrypt
|
|
/// the whole operation is rolled back, leaving the database unchanged.
|
|
pub async fn change_user_key(
|
|
pool: &PgPool,
|
|
user_id: Uuid,
|
|
old_key: &[u8; 32],
|
|
new_key: &[u8; 32],
|
|
new_salt: &[u8],
|
|
new_key_check: &[u8],
|
|
new_key_params: &Value,
|
|
) -> Result<()> {
|
|
let mut tx = pool.begin().await?;
|
|
|
|
let secrets: Vec<(uuid::Uuid, Vec<u8>)> =
|
|
sqlx::query_as("SELECT id, encrypted FROM secrets WHERE user_id = $1 FOR UPDATE")
|
|
.bind(user_id)
|
|
.fetch_all(&mut *tx)
|
|
.await?;
|
|
|
|
for (id, encrypted) in &secrets {
|
|
let plaintext = crate::crypto::decrypt(old_key, encrypted)?;
|
|
let new_encrypted = crate::crypto::encrypt(new_key, &plaintext)?;
|
|
sqlx::query("UPDATE secrets SET encrypted = $1, updated_at = NOW() WHERE id = $2")
|
|
.bind(&new_encrypted)
|
|
.bind(id)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
}
|
|
|
|
sqlx::query(
|
|
"UPDATE users SET key_salt = $1, key_check = $2, key_params = $3, \
|
|
key_version = key_version + 1, updated_at = NOW() \
|
|
WHERE id = $4",
|
|
)
|
|
.bind(new_salt)
|
|
.bind(new_key_check)
|
|
.bind(new_key_params)
|
|
.bind(user_id)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
|
|
tx.commit().await?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Store the PBKDF2 salt, key_check, and params for a user's passphrase setup.
|
|
pub async fn update_user_key_setup(
|
|
pool: &PgPool,
|
|
user_id: Uuid,
|
|
key_salt: &[u8],
|
|
key_check: &[u8],
|
|
key_params: &Value,
|
|
) -> Result<()> {
|
|
sqlx::query(
|
|
"UPDATE users SET key_salt = $1, key_check = $2, key_params = $3, updated_at = NOW() \
|
|
WHERE id = $4",
|
|
)
|
|
.bind(key_salt)
|
|
.bind(key_check)
|
|
.bind(key_params)
|
|
.bind(user_id)
|
|
.execute(pool)
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Fetch a user by ID.
|
|
pub async fn get_user_by_id(pool: &PgPool, user_id: Uuid) -> Result<Option<User>> {
|
|
let user = sqlx::query_as(
|
|
"SELECT id, email, name, avatar_url, key_salt, key_check, key_params, api_key, key_version, created_at, updated_at \
|
|
FROM users WHERE id = $1",
|
|
)
|
|
.bind(user_id)
|
|
.fetch_optional(pool)
|
|
.await?;
|
|
Ok(user)
|
|
}
|
|
|
|
/// List all OAuth accounts linked to a user.
|
|
pub async fn list_oauth_accounts(pool: &PgPool, user_id: Uuid) -> Result<Vec<OauthAccount>> {
|
|
let accounts = sqlx::query_as(
|
|
"SELECT id, user_id, provider, provider_id, email, name, avatar_url, created_at \
|
|
FROM oauth_accounts WHERE user_id = $1 ORDER BY created_at",
|
|
)
|
|
.bind(user_id)
|
|
.fetch_all(pool)
|
|
.await?;
|
|
Ok(accounts)
|
|
}
|
|
|
|
/// Bind an additional OAuth account to an existing user.
|
|
pub async fn bind_oauth_account(
|
|
pool: &PgPool,
|
|
user_id: Uuid,
|
|
profile: OAuthProfile,
|
|
) -> Result<OauthAccount> {
|
|
// Use a transaction with FOR UPDATE to prevent TOCTOU race conditions
|
|
let mut tx = pool.begin().await?;
|
|
|
|
// Check if this provider_id is already linked to someone else (with row lock)
|
|
let conflict: Option<(Uuid,)> = sqlx::query_as(
|
|
"SELECT user_id FROM oauth_accounts WHERE provider = $1 AND provider_id = $2 FOR UPDATE",
|
|
)
|
|
.bind(&profile.provider)
|
|
.bind(&profile.provider_id)
|
|
.fetch_optional(&mut *tx)
|
|
.await?;
|
|
|
|
if let Some((existing_user_id,)) = conflict {
|
|
if existing_user_id != user_id {
|
|
anyhow::bail!(
|
|
"This {} account is already linked to a different user",
|
|
profile.provider
|
|
);
|
|
}
|
|
anyhow::bail!(
|
|
"This {} account is already linked to your account",
|
|
profile.provider
|
|
);
|
|
}
|
|
|
|
let existing_provider_for_user: Option<(String,)> = sqlx::query_as(
|
|
"SELECT provider_id FROM oauth_accounts WHERE user_id = $1 AND provider = $2 FOR UPDATE",
|
|
)
|
|
.bind(user_id)
|
|
.bind(&profile.provider)
|
|
.fetch_optional(&mut *tx)
|
|
.await?;
|
|
|
|
if existing_provider_for_user.is_some() {
|
|
anyhow::bail!(
|
|
"You already linked a {} account. Unlink the other provider instead of binding multiple {} accounts.",
|
|
profile.provider,
|
|
profile.provider
|
|
);
|
|
}
|
|
|
|
let account: OauthAccount = sqlx::query_as(
|
|
"INSERT INTO oauth_accounts (user_id, provider, provider_id, email, name, avatar_url) \
|
|
VALUES ($1, $2, $3, $4, $5, $6) \
|
|
RETURNING id, user_id, provider, provider_id, email, name, avatar_url, created_at",
|
|
)
|
|
.bind(user_id)
|
|
.bind(&profile.provider)
|
|
.bind(&profile.provider_id)
|
|
.bind(&profile.email)
|
|
.bind(&profile.name)
|
|
.bind(&profile.avatar_url)
|
|
.fetch_one(&mut *tx)
|
|
.await?;
|
|
|
|
tx.commit().await?;
|
|
Ok(account)
|
|
}
|
|
|
|
/// Unbind an OAuth account. Ensures at least one remains and blocks unlinking the current login provider.
|
|
pub async fn unbind_oauth_account(
|
|
pool: &PgPool,
|
|
user_id: Uuid,
|
|
provider: &str,
|
|
current_login_provider: Option<&str>,
|
|
) -> Result<()> {
|
|
if current_login_provider == Some(provider) {
|
|
anyhow::bail!(
|
|
"Cannot unlink the {} account you are currently using to sign in",
|
|
provider
|
|
);
|
|
}
|
|
|
|
let mut tx = pool.begin().await?;
|
|
|
|
let locked_accounts: Vec<(String,)> =
|
|
sqlx::query_as("SELECT provider FROM oauth_accounts WHERE user_id = $1 FOR UPDATE")
|
|
.bind(user_id)
|
|
.fetch_all(&mut *tx)
|
|
.await?;
|
|
let count = locked_accounts.len();
|
|
|
|
if count <= 1 {
|
|
anyhow::bail!("Cannot unbind the last OAuth account. Please link another account first.");
|
|
}
|
|
|
|
sqlx::query("DELETE FROM oauth_accounts WHERE user_id = $1 AND provider = $2")
|
|
.bind(user_id)
|
|
.bind(provider)
|
|
.execute(&mut *tx)
|
|
.await?;
|
|
|
|
tx.commit().await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Connect to the database named by `SECRETS_DATABASE_URL` and run migrations.
    ///
    /// Returns `None` (the calling test then passes without doing anything) when the variable
    /// is unset, the database is unreachable, or migrations fail.
    async fn maybe_test_pool() -> Option<PgPool> {
        let database_url = match std::env::var("SECRETS_DATABASE_URL") {
            Ok(v) => v,
            Err(_) => {
                eprintln!("skip user service tests: SECRETS_DATABASE_URL not set");
                return None;
            }
        };
        let pool = match sqlx::PgPool::connect(&database_url).await {
            Ok(pool) => pool,
            Err(e) => {
                eprintln!("skip user service tests: cannot connect to database: {e}");
                return None;
            }
        };
        if let Err(e) = crate::db::migrate(&pool).await {
            eprintln!("skip user service tests: migrate failed: {e}");
            return None;
        }
        Some(pool)
    }

    /// Delete the `oauth_accounts` and `users` rows created for `user_id`.
    async fn cleanup_user_rows(pool: &PgPool, user_id: Uuid) -> Result<()> {
        sqlx::query("DELETE FROM oauth_accounts WHERE user_id = $1")
            .bind(user_id)
            .execute(pool)
            .await?;
        sqlx::query("DELETE FROM users WHERE id = $1")
            .bind(user_id)
            .execute(pool)
            .await?;
        Ok(())
    }

    /// Insert a bare user plus one OAuth account per entry of `providers`.
    ///
    /// Each account's `provider_id` is `"{provider}-{user_id}"`.
    async fn seed_user(pool: &PgPool, user_id: Uuid, providers: &[&str]) -> Result<()> {
        sqlx::query("INSERT INTO users (id, name) VALUES ($1, '')")
            .bind(user_id)
            .execute(pool)
            .await?;
        for &provider in providers {
            sqlx::query(
                "INSERT INTO oauth_accounts (user_id, provider, provider_id, email, name, avatar_url) \
                 VALUES ($1, $2, $3, NULL, NULL, NULL)",
            )
            .bind(user_id)
            .bind(provider)
            .bind(format!("{provider}-{user_id}"))
            .execute(pool)
            .await?;
        }
        Ok(())
    }

    /// Providers still linked to `user_id`, sorted by name.
    async fn linked_providers(pool: &PgPool, user_id: Uuid) -> Result<Vec<String>> {
        let rows: Vec<(String,)> = sqlx::query_as(
            "SELECT provider FROM oauth_accounts WHERE user_id = $1 ORDER BY provider",
        )
        .bind(user_id)
        .fetch_all(pool)
        .await?;
        Ok(rows.into_iter().map(|(p,)| p).collect())
    }

    #[tokio::test]
    async fn unbind_oauth_account_removes_only_requested_provider() -> Result<()> {
        let Some(pool) = maybe_test_pool().await else {
            return Ok(());
        };
        let user_id = Uuid::from_u128(rand::random());

        cleanup_user_rows(&pool, user_id).await?;
        seed_user(&pool, user_id, &["google", "github"]).await?;

        let outcome = unbind_oauth_account(&pool, user_id, "github", Some("google")).await;
        let remaining = linked_providers(&pool, user_id).await;

        // Clean up before asserting so a failure does not leave rows behind.
        cleanup_user_rows(&pool, user_id).await?;
        outcome?;
        assert_eq!(remaining?, vec!["google".to_string()]);
        Ok(())
    }

    #[tokio::test]
    async fn unbind_oauth_account_refuses_last_account() -> Result<()> {
        let Some(pool) = maybe_test_pool().await else {
            return Ok(());
        };
        let user_id = Uuid::from_u128(rand::random());

        cleanup_user_rows(&pool, user_id).await?;
        seed_user(&pool, user_id, &["google"]).await?;

        let outcome = unbind_oauth_account(&pool, user_id, "google", None).await;
        let remaining = linked_providers(&pool, user_id).await;

        cleanup_user_rows(&pool, user_id).await?;
        assert!(outcome.is_err(), "unbinding the only account must fail");
        assert_eq!(remaining?, vec!["google".to_string()]);
        Ok(())
    }

    #[tokio::test]
    async fn unbind_oauth_account_refuses_current_login_provider() -> Result<()> {
        let Some(pool) = maybe_test_pool().await else {
            return Ok(());
        };
        let user_id = Uuid::from_u128(rand::random());

        cleanup_user_rows(&pool, user_id).await?;
        seed_user(&pool, user_id, &["google", "github"]).await?;

        let outcome = unbind_oauth_account(&pool, user_id, "github", Some("github")).await;
        let remaining = linked_providers(&pool, user_id).await;

        cleanup_user_rows(&pool, user_id).await?;
        assert!(
            outcome.is_err(),
            "unbinding the provider used for the current sign-in must fail"
        );
        assert_eq!(
            remaining?,
            vec!["github".to_string(), "google".to_string()]
        );
        Ok(())
    }
}
|