1100 lines
32 KiB
Rust
1100 lines
32 KiB
Rust
use anyhow::{Context, Result as AnyResult};
|
|
use axum::{
|
|
Json, Router,
|
|
extract::{Path, Query, State},
|
|
http::{HeaderMap, StatusCode, header},
|
|
response::{Html, IntoResponse, Redirect},
|
|
routing::{get, post},
|
|
};
|
|
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
|
|
use chrono::{DateTime, Duration, Utc};
|
|
use reqwest::Client;
|
|
use secrets_application::sync::{fetch_object, sync_pull, sync_push};
|
|
use secrets_device_auth::{hash_device_login_token, new_device_fingerprint, new_device_login_token};
|
|
use secrets_domain::{
|
|
SyncPullRequest, SyncPullResponse, SyncPushRequest, SyncPushResponse, VaultObjectEnvelope,
|
|
};
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::json;
|
|
use sha2::{Digest, Sha256};
|
|
use sqlx::{PgPool, Postgres, Transaction};
|
|
use tracing_subscriber::EnvFilter;
|
|
use url::Url;
|
|
use uuid::Uuid;
|
|
|
|
#[derive(Clone)]
|
|
struct AppState {
|
|
pool: PgPool,
|
|
http: Client,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct DemoLoginResponse {
|
|
device_token: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct DesktopLoginStartRequest {
|
|
device_name: String,
|
|
platform: String,
|
|
client_version: String,
|
|
device_fingerprint: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct DesktopLoginPollQuery {
|
|
session_id: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct GoogleStartQuery {
|
|
session_id: String,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct GoogleCallbackQuery {
|
|
state: Option<String>,
|
|
code: Option<String>,
|
|
error: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct GoogleUserInfo {
|
|
sub: String,
|
|
email: String,
|
|
name: Option<String>,
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
struct GoogleOAuthConfig {
|
|
client_id: String,
|
|
client_secret: String,
|
|
auth_uri: String,
|
|
token_uri: String,
|
|
redirect_uri: String,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct DesktopLoginStartResponse {
|
|
session_id: String,
|
|
auth_url: String,
|
|
expires_at: String,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct DesktopLoginPollResponse {
|
|
status: String,
|
|
device_token: Option<String>,
|
|
error: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct GoogleTokenResponse {
|
|
access_token: String,
|
|
}
|
|
|
|
#[derive(Debug, sqlx::FromRow)]
|
|
struct DesktopLoginSessionRow {
|
|
session_id: String,
|
|
oauth_state: String,
|
|
pkce_verifier: String,
|
|
device_name: String,
|
|
platform: String,
|
|
client_version: String,
|
|
device_fingerprint: String,
|
|
status: String,
|
|
error_message: Option<String>,
|
|
expires_at: DateTime<Utc>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct DeviceView {
|
|
name: String,
|
|
platform: String,
|
|
client_version: String,
|
|
last_seen: String,
|
|
ip: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct UserProfileView {
|
|
id: Uuid,
|
|
name: String,
|
|
email: String,
|
|
}
|
|
|
|
#[derive(Serialize, sqlx::FromRow)]
|
|
struct UserRow {
|
|
id: Uuid,
|
|
email: Option<String>,
|
|
name: String,
|
|
}
|
|
|
|
#[derive(Serialize, sqlx::FromRow)]
|
|
struct DeviceRow {
|
|
id: Uuid,
|
|
display_name: String,
|
|
platform: String,
|
|
client_version: String,
|
|
last_seen_at: DateTime<Utc>,
|
|
last_ip: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Serialize)]
|
|
struct ObjectResponse {
|
|
object: VaultObjectEnvelope,
|
|
}
|
|
|
|
// Values stored in `desktop_login_sessions.status`. The handlers in this file move a session
// pending -> succeeded -> consumed, or pending -> failed / expired.

/// Waiting for the user to finish Google sign-in.
const LOGIN_STATUS_PENDING: &str = "pending";
/// Sign-in finished; a device token is waiting to be collected by the desktop poll.
const LOGIN_STATUS_SUCCEEDED: &str = "succeeded";
/// Sign-in did not complete (for example, Google returned an `error` on the callback).
const LOGIN_STATUS_FAILED: &str = "failed";
/// The session passed `expires_at` while still pending.
const LOGIN_STATUS_EXPIRED: &str = "expired";
/// The desktop client has already collected the device token.
const LOGIN_STATUS_CONSUMED: &str = "consumed";

/// Lifetime of a desktop login session, counted from `/auth/desktop/start`.
const DESKTOP_LOGIN_SESSION_TTL_MINUTES: i64 = 10;
|
|
|
|
#[tokio::main]
|
|
async fn main() -> AnyResult<()> {
|
|
let _ = dotenvy::dotenv();
|
|
|
|
tracing_subscriber::fmt()
|
|
.with_env_filter(
|
|
EnvFilter::try_from_default_env().unwrap_or_else(|_| "secrets_api=info".into()),
|
|
)
|
|
.init();
|
|
|
|
let database_url = secrets_infrastructure_db::load_database_url()?;
|
|
let pool = secrets_infrastructure_db::create_pool(&database_url).await?;
|
|
secrets_infrastructure_db::migrate_current_schema(&pool)
|
|
.await
|
|
.context("failed to initialize current database schema")?;
|
|
|
|
let bind = std::env::var("SECRETS_API_BIND").unwrap_or_else(|_| "127.0.0.1:9415".to_string());
|
|
let app = Router::new()
|
|
.route("/healthz", get(|| async { "ok" }))
|
|
.route("/auth/demo-login", post(api_demo_login))
|
|
.route("/auth/desktop/start", post(api_desktop_login_start))
|
|
.route("/auth/desktop/poll", get(api_desktop_login_poll))
|
|
.route("/auth/google/start", get(api_google_login_start))
|
|
.route("/auth/google/callback", get(api_google_login_callback))
|
|
.route("/me", get(api_me))
|
|
.route("/sync/pull", post(api_sync_pull))
|
|
.route("/sync/push", post(api_sync_push))
|
|
.route("/sync/objects/{id}", get(api_sync_object))
|
|
.route("/devices", get(api_devices))
|
|
.with_state(AppState {
|
|
pool,
|
|
http: Client::new(),
|
|
});
|
|
let listener = tokio::net::TcpListener::bind(&bind)
|
|
.await
|
|
.with_context(|| format!("failed to bind {}", bind))?;
|
|
|
|
tracing::info!(bind = %bind, "secrets-api listening");
|
|
axum::serve(listener, app)
|
|
.await
|
|
.context("api server error")?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn api_demo_login(
|
|
State(state): State<AppState>,
|
|
) -> std::result::Result<Json<DemoLoginResponse>, (StatusCode, Json<serde_json::Value>)> {
|
|
let (user_id, device_id) = ensure_demo_user(&state.pool)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
let device_token = new_device_login_token();
|
|
let token_hash = hash_device_login_token(&device_token);
|
|
|
|
sqlx::query("DELETE FROM device_login_tokens WHERE device_id = $1")
|
|
.bind(device_id)
|
|
.execute(&state.pool)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO device_login_tokens (device_id, token_hash)
|
|
VALUES ($1, $2)
|
|
"#,
|
|
)
|
|
.bind(device_id)
|
|
.bind(token_hash)
|
|
.execute(&state.pool)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO auth_events (
|
|
user_id, device_id, device_name, platform, client_version, ip_addr, forwarded_ip, login_method, login_result
|
|
)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, 'device_token', 'success')
|
|
"#,
|
|
)
|
|
.bind(user_id)
|
|
.bind(device_id)
|
|
.bind("Voson 的 Mac mini")
|
|
.bind("macOS")
|
|
.bind(env!("CARGO_PKG_VERSION"))
|
|
.bind::<Option<String>>(None)
|
|
.bind::<Option<String>>(None)
|
|
.execute(&state.pool)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
|
|
Ok(Json(DemoLoginResponse { device_token }))
|
|
}
|
|
|
|
async fn api_desktop_login_start(
|
|
State(state): State<AppState>,
|
|
Json(payload): Json<DesktopLoginStartRequest>,
|
|
) -> std::result::Result<Json<DesktopLoginStartResponse>, (StatusCode, Json<serde_json::Value>)> {
|
|
let session_id = new_session_secret();
|
|
let oauth_state = new_session_secret();
|
|
let pkce_verifier = new_session_secret();
|
|
let expires_at = Utc::now() + Duration::minutes(DESKTOP_LOGIN_SESSION_TTL_MINUTES);
|
|
let auth_url = format!(
|
|
"{}/auth/google/start?session_id={}",
|
|
public_base_url().map_err(internal_error)?,
|
|
session_id
|
|
);
|
|
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO desktop_login_sessions (
|
|
session_id, oauth_state, pkce_verifier, device_name, platform, client_version,
|
|
device_fingerprint, status, expires_at
|
|
)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
|
"#,
|
|
)
|
|
.bind(&session_id)
|
|
.bind(&oauth_state)
|
|
.bind(&pkce_verifier)
|
|
.bind(&payload.device_name)
|
|
.bind(&payload.platform)
|
|
.bind(&payload.client_version)
|
|
.bind(&payload.device_fingerprint)
|
|
.bind(LOGIN_STATUS_PENDING)
|
|
.bind(expires_at)
|
|
.execute(&state.pool)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
|
|
Ok(Json(DesktopLoginStartResponse {
|
|
session_id,
|
|
auth_url,
|
|
expires_at: expires_at.to_rfc3339(),
|
|
}))
|
|
}
|
|
|
|
async fn api_google_login_start(
|
|
State(state): State<AppState>,
|
|
Query(query): Query<GoogleStartQuery>,
|
|
) -> std::result::Result<Redirect, (StatusCode, Json<serde_json::Value>)> {
|
|
let session = fetch_desktop_login_session(&state.pool, &query.session_id)
|
|
.await
|
|
.map_err(internal_error)?
|
|
.ok_or_else(|| unauthorized("desktop login session not found"))?;
|
|
ensure_login_session_pending(&session).map_err(unauthorized)?;
|
|
let google = google_oauth_config().map_err(internal_error)?;
|
|
let challenge = pkce_challenge(&session.pkce_verifier);
|
|
|
|
let mut auth_url = Url::parse(&google.auth_uri).map_err(internal_error)?;
|
|
auth_url
|
|
.query_pairs_mut()
|
|
.append_pair("client_id", &google.client_id)
|
|
.append_pair("redirect_uri", &google.redirect_uri)
|
|
.append_pair("response_type", "code")
|
|
.append_pair("scope", "openid email profile")
|
|
.append_pair("state", &session.oauth_state)
|
|
.append_pair("code_challenge", &challenge)
|
|
.append_pair("code_challenge_method", "S256")
|
|
.append_pair("access_type", "offline")
|
|
.append_pair("prompt", "consent");
|
|
|
|
Ok(Redirect::temporary(auth_url.as_str()))
|
|
}
|
|
|
|
async fn api_google_login_callback(
|
|
State(state): State<AppState>,
|
|
Query(query): Query<GoogleCallbackQuery>,
|
|
) -> std::result::Result<impl IntoResponse, (StatusCode, Json<serde_json::Value>)> {
|
|
let oauth_state = query
|
|
.state
|
|
.as_deref()
|
|
.filter(|value| !value.is_empty())
|
|
.ok_or_else(|| unauthorized("missing oauth state"))?;
|
|
let mut tx = state.pool.begin().await.map_err(internal_error)?;
|
|
let session = fetch_desktop_login_session_by_state(&mut tx, oauth_state)
|
|
.await
|
|
.map_err(internal_error)?
|
|
.ok_or_else(|| unauthorized("desktop login session not found"))?;
|
|
|
|
if let Some(error) = query.error.as_deref().filter(|value| !value.is_empty()) {
|
|
mark_login_session_failed(&mut tx, &session.session_id, &format!("google oauth error: {error}"))
|
|
.await
|
|
.map_err(internal_error)?;
|
|
tx.commit().await.map_err(internal_error)?;
|
|
return Ok(Html(login_result_html(
|
|
"登录未完成",
|
|
"你已取消 Google 授权或授权未成功,可以返回 Secrets 重试。",
|
|
)));
|
|
}
|
|
|
|
ensure_login_session_pending(&session).map_err(unauthorized)?;
|
|
let code = query
|
|
.code
|
|
.as_deref()
|
|
.filter(|value| !value.is_empty())
|
|
.ok_or_else(|| unauthorized("missing google auth code"))?;
|
|
let google = google_oauth_config().map_err(internal_error)?;
|
|
let google_token = exchange_google_auth_code(&state.http, &google, code, &session.pkce_verifier)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
let google_user = fetch_google_userinfo(&state.http, &google_token.access_token)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
|
|
let user_id = upsert_user_from_google(&state.pool, &google_user)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
upsert_google_oauth_account(&state.pool, user_id, &google_user)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
let device_id = upsert_device_for_login(
|
|
&state.pool,
|
|
user_id,
|
|
&session.device_name,
|
|
&session.platform,
|
|
&session.client_version,
|
|
&session.device_fingerprint,
|
|
)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
|
|
let device_token = issue_device_login_token(
|
|
&state.pool,
|
|
user_id,
|
|
device_id,
|
|
&session.device_name,
|
|
&session.platform,
|
|
&session.client_version,
|
|
)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
|
|
mark_login_session_succeeded(
|
|
&mut tx,
|
|
&session.session_id,
|
|
user_id,
|
|
device_id,
|
|
device_token.clone(),
|
|
hash_device_login_token(&device_token),
|
|
)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
tx.commit().await.map_err(internal_error)?;
|
|
|
|
Ok(Html(login_result_html(
|
|
"登录成功",
|
|
"Google 授权已完成,可以返回 Secrets 桌面端继续。",
|
|
)))
|
|
}
|
|
|
|
async fn api_desktop_login_poll(
|
|
State(state): State<AppState>,
|
|
Query(query): Query<DesktopLoginPollQuery>,
|
|
) -> std::result::Result<Json<DesktopLoginPollResponse>, (StatusCode, Json<serde_json::Value>)> {
|
|
let mut tx = state.pool.begin().await.map_err(internal_error)?;
|
|
let session = fetch_desktop_login_session_for_update(&mut tx, &query.session_id)
|
|
.await
|
|
.map_err(internal_error)?
|
|
.ok_or_else(|| unauthorized("desktop login session not found"))?;
|
|
let now = Utc::now();
|
|
|
|
if session.expires_at < now && session.status == LOGIN_STATUS_PENDING {
|
|
mark_login_session_expired(&mut tx, &session.session_id)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
tx.commit().await.map_err(internal_error)?;
|
|
return Ok(Json(DesktopLoginPollResponse {
|
|
status: LOGIN_STATUS_EXPIRED.to_string(),
|
|
device_token: None,
|
|
error: Some("login session expired".to_string()),
|
|
}));
|
|
}
|
|
|
|
match session.status.as_str() {
|
|
LOGIN_STATUS_PENDING => {
|
|
tx.commit().await.map_err(internal_error)?;
|
|
Ok(Json(DesktopLoginPollResponse {
|
|
status: LOGIN_STATUS_PENDING.to_string(),
|
|
device_token: None,
|
|
error: None,
|
|
}))
|
|
}
|
|
LOGIN_STATUS_FAILED => {
|
|
tx.commit().await.map_err(internal_error)?;
|
|
Ok(Json(DesktopLoginPollResponse {
|
|
status: LOGIN_STATUS_FAILED.to_string(),
|
|
device_token: None,
|
|
error: session.error_message,
|
|
}))
|
|
}
|
|
LOGIN_STATUS_EXPIRED => {
|
|
tx.commit().await.map_err(internal_error)?;
|
|
Ok(Json(DesktopLoginPollResponse {
|
|
status: LOGIN_STATUS_EXPIRED.to_string(),
|
|
device_token: None,
|
|
error: session.error_message.or(Some("login session expired".to_string())),
|
|
}))
|
|
}
|
|
LOGIN_STATUS_CONSUMED => {
|
|
tx.commit().await.map_err(internal_error)?;
|
|
Ok(Json(DesktopLoginPollResponse {
|
|
status: LOGIN_STATUS_CONSUMED.to_string(),
|
|
device_token: None,
|
|
error: Some("login session already consumed".to_string()),
|
|
}))
|
|
}
|
|
LOGIN_STATUS_SUCCEEDED => {
|
|
let device_token = consume_device_token_for_poll(&mut tx, &session.session_id)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
sqlx::query(
|
|
"UPDATE desktop_login_sessions SET status = $2, consumed_at = NOW(), updated_at = NOW() WHERE session_id = $1",
|
|
)
|
|
.bind(&session.session_id)
|
|
.bind(LOGIN_STATUS_CONSUMED)
|
|
.execute(&mut *tx)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
tx.commit().await.map_err(internal_error)?;
|
|
Ok(Json(DesktopLoginPollResponse {
|
|
status: LOGIN_STATUS_SUCCEEDED.to_string(),
|
|
device_token: Some(device_token),
|
|
error: None,
|
|
}))
|
|
}
|
|
_ => {
|
|
tx.commit().await.map_err(internal_error)?;
|
|
Ok(Json(DesktopLoginPollResponse {
|
|
status: LOGIN_STATUS_FAILED.to_string(),
|
|
device_token: None,
|
|
error: Some("invalid login session status".to_string()),
|
|
}))
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn api_sync_pull(
|
|
State(state): State<AppState>,
|
|
headers: HeaderMap,
|
|
Json(payload): Json<SyncPullRequest>,
|
|
) -> std::result::Result<Json<SyncPullResponse>, (StatusCode, Json<serde_json::Value>)> {
|
|
let (user, _) = require_auth(&state.pool, &headers).await?;
|
|
let response = sync_pull(&state.pool, user.id, payload)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
Ok(Json(response))
|
|
}
|
|
|
|
async fn api_sync_push(
|
|
State(state): State<AppState>,
|
|
headers: HeaderMap,
|
|
Json(payload): Json<SyncPushRequest>,
|
|
) -> std::result::Result<Json<SyncPushResponse>, (StatusCode, Json<serde_json::Value>)> {
|
|
let (user, _) = require_auth(&state.pool, &headers).await?;
|
|
let response = sync_push(&state.pool, user.id, payload)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
Ok(Json(response))
|
|
}
|
|
|
|
async fn api_sync_object(
|
|
State(state): State<AppState>,
|
|
headers: HeaderMap,
|
|
Path(object_id): Path<Uuid>,
|
|
) -> std::result::Result<Json<ObjectResponse>, (StatusCode, Json<serde_json::Value>)> {
|
|
let (user, _) = require_auth(&state.pool, &headers).await?;
|
|
let object = fetch_object(&state.pool, user.id, object_id)
|
|
.await
|
|
.map_err(internal_error)?
|
|
.ok_or_else(|| unauthorized("object not found"))?;
|
|
Ok(Json(ObjectResponse { object }))
|
|
}
|
|
|
|
async fn api_devices(
|
|
State(state): State<AppState>,
|
|
headers: HeaderMap,
|
|
) -> std::result::Result<Json<Vec<DeviceView>>, (StatusCode, Json<serde_json::Value>)> {
|
|
let (user, _) = require_auth(&state.pool, &headers).await?;
|
|
let rows = sqlx::query_as::<_, DeviceRow>(
|
|
r#"
|
|
SELECT
|
|
d.id,
|
|
d.display_name,
|
|
d.platform,
|
|
d.client_version,
|
|
d.last_seen_at,
|
|
COALESCE(NULLIF(a.forwarded_ip, ''), NULLIF(a.ip_addr, '')) AS last_ip
|
|
FROM devices d
|
|
LEFT JOIN LATERAL (
|
|
SELECT ip_addr, forwarded_ip
|
|
FROM auth_events
|
|
WHERE device_id = d.id
|
|
ORDER BY created_at DESC
|
|
LIMIT 1
|
|
) a ON TRUE
|
|
WHERE d.user_id = $1
|
|
ORDER BY last_seen_at DESC
|
|
"#,
|
|
)
|
|
.bind(user.id)
|
|
.fetch_all(&state.pool)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
|
|
let devices = rows
|
|
.into_iter()
|
|
.map(|row| DeviceView {
|
|
name: row.display_name,
|
|
platform: row.platform,
|
|
client_version: row.client_version,
|
|
last_seen: row.last_seen_at.format("%Y-%m-%d %H:%M").to_string(),
|
|
ip: row.last_ip,
|
|
})
|
|
.collect();
|
|
|
|
Ok(Json(devices))
|
|
}
|
|
|
|
async fn api_me(
|
|
State(state): State<AppState>,
|
|
headers: HeaderMap,
|
|
) -> std::result::Result<Json<UserProfileView>, (StatusCode, Json<serde_json::Value>)> {
|
|
let (user, _) = require_auth(&state.pool, &headers).await?;
|
|
Ok(Json(UserProfileView {
|
|
id: user.id,
|
|
name: user.name,
|
|
email: user.email.unwrap_or_default(),
|
|
}))
|
|
}
|
|
|
|
fn public_base_url() -> AnyResult<String> {
|
|
std::env::var("SECRETS_PUBLIC_BASE_URL")
|
|
.or_else(|_| std::env::var("SECRETS_API_BASE"))
|
|
.context("SECRETS_PUBLIC_BASE_URL or SECRETS_API_BASE must be set")
|
|
}
|
|
|
|
fn google_oauth_config() -> AnyResult<GoogleOAuthConfig> {
|
|
Ok(GoogleOAuthConfig {
|
|
client_id: std::env::var("GOOGLE_OAUTH_CLIENT_ID")
|
|
.context("GOOGLE_OAUTH_CLIENT_ID is not set")?,
|
|
client_secret: std::env::var("GOOGLE_OAUTH_CLIENT_SECRET")
|
|
.context("GOOGLE_OAUTH_CLIENT_SECRET is not set")?,
|
|
auth_uri: std::env::var("GOOGLE_OAUTH_AUTH_URI")
|
|
.unwrap_or_else(|_| "https://accounts.google.com/o/oauth2/v2/auth".to_string()),
|
|
token_uri: std::env::var("GOOGLE_OAUTH_TOKEN_URI")
|
|
.unwrap_or_else(|_| "https://oauth2.googleapis.com/token".to_string()),
|
|
redirect_uri: std::env::var("GOOGLE_OAUTH_REDIRECT_URI")
|
|
.context("GOOGLE_OAUTH_REDIRECT_URI is not set")?,
|
|
})
|
|
}
|
|
|
|
fn new_session_secret() -> String {
|
|
new_device_login_token()
|
|
}
|
|
|
|
fn pkce_challenge(verifier: &str) -> String {
|
|
let digest = Sha256::digest(verifier.as_bytes());
|
|
URL_SAFE_NO_PAD.encode(digest)
|
|
}
|
|
|
|
/// Renders the minimal HTML page shown in the browser after the Google callback.
///
/// `title` and `message` are inserted verbatim (no escaping). Callers must pass
/// trusted text; every call site in this file passes string literals.
fn login_result_html(title: &str, message: &str) -> String {
    format!(
        "<html><body><h3>{title}</h3><p>{message}</p><p>现在可以返回 Secrets 桌面端。</p></body></html>"
    )
}
|
|
|
|
fn ensure_login_session_pending(session: &DesktopLoginSessionRow) -> Result<(), &'static str> {
|
|
if session.expires_at < Utc::now() {
|
|
return Err("desktop login session expired");
|
|
}
|
|
if session.status != LOGIN_STATUS_PENDING {
|
|
return Err("desktop login session is no longer pending");
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
async fn fetch_desktop_login_session(
|
|
pool: &PgPool,
|
|
session_id: &str,
|
|
) -> AnyResult<Option<DesktopLoginSessionRow>> {
|
|
sqlx::query_as::<_, DesktopLoginSessionRow>(
|
|
r#"
|
|
SELECT
|
|
session_id, oauth_state, pkce_verifier, device_name, platform, client_version,
|
|
device_fingerprint, status, error_message, user_id, device_id, device_token,
|
|
device_token_hash, expires_at
|
|
FROM desktop_login_sessions
|
|
WHERE session_id = $1
|
|
"#,
|
|
)
|
|
.bind(session_id)
|
|
.fetch_optional(pool)
|
|
.await
|
|
.context("failed to load desktop login session")
|
|
}
|
|
|
|
async fn fetch_desktop_login_session_for_update(
|
|
tx: &mut Transaction<'_, Postgres>,
|
|
session_id: &str,
|
|
) -> AnyResult<Option<DesktopLoginSessionRow>> {
|
|
sqlx::query_as::<_, DesktopLoginSessionRow>(
|
|
r#"
|
|
SELECT
|
|
session_id, oauth_state, pkce_verifier, device_name, platform, client_version,
|
|
device_fingerprint, status, error_message, user_id, device_id, device_token,
|
|
device_token_hash, expires_at
|
|
FROM desktop_login_sessions
|
|
WHERE session_id = $1
|
|
FOR UPDATE
|
|
"#,
|
|
)
|
|
.bind(session_id)
|
|
.fetch_optional(&mut **tx)
|
|
.await
|
|
.context("failed to lock desktop login session")
|
|
}
|
|
|
|
async fn fetch_desktop_login_session_by_state(
|
|
tx: &mut Transaction<'_, Postgres>,
|
|
oauth_state: &str,
|
|
) -> AnyResult<Option<DesktopLoginSessionRow>> {
|
|
sqlx::query_as::<_, DesktopLoginSessionRow>(
|
|
r#"
|
|
SELECT
|
|
session_id, oauth_state, pkce_verifier, device_name, platform, client_version,
|
|
device_fingerprint, status, error_message, user_id, device_id, device_token,
|
|
device_token_hash, expires_at
|
|
FROM desktop_login_sessions
|
|
WHERE oauth_state = $1
|
|
FOR UPDATE
|
|
"#,
|
|
)
|
|
.bind(oauth_state)
|
|
.fetch_optional(&mut **tx)
|
|
.await
|
|
.context("failed to load desktop login session by oauth state")
|
|
}
|
|
|
|
async fn mark_login_session_failed(
|
|
tx: &mut Transaction<'_, Postgres>,
|
|
session_id: &str,
|
|
message: &str,
|
|
) -> AnyResult<()> {
|
|
sqlx::query(
|
|
"UPDATE desktop_login_sessions SET status = $2, error_message = $3, updated_at = NOW() WHERE session_id = $1",
|
|
)
|
|
.bind(session_id)
|
|
.bind(LOGIN_STATUS_FAILED)
|
|
.bind(message)
|
|
.execute(&mut **tx)
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn mark_login_session_expired(
|
|
tx: &mut Transaction<'_, Postgres>,
|
|
session_id: &str,
|
|
) -> AnyResult<()> {
|
|
sqlx::query(
|
|
"UPDATE desktop_login_sessions SET status = $2, error_message = $3, updated_at = NOW() WHERE session_id = $1",
|
|
)
|
|
.bind(session_id)
|
|
.bind(LOGIN_STATUS_EXPIRED)
|
|
.bind("login session expired")
|
|
.execute(&mut **tx)
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn mark_login_session_succeeded(
|
|
tx: &mut Transaction<'_, Postgres>,
|
|
session_id: &str,
|
|
user_id: Uuid,
|
|
device_id: Uuid,
|
|
device_token: String,
|
|
device_token_hash: String,
|
|
) -> AnyResult<()> {
|
|
sqlx::query(
|
|
r#"
|
|
UPDATE desktop_login_sessions
|
|
SET status = $2,
|
|
user_id = $3,
|
|
device_id = $4,
|
|
device_token = $5,
|
|
device_token_hash = $6,
|
|
updated_at = NOW()
|
|
WHERE session_id = $1
|
|
"#,
|
|
)
|
|
.bind(session_id)
|
|
.bind(LOGIN_STATUS_SUCCEEDED)
|
|
.bind(user_id)
|
|
.bind(device_id)
|
|
.bind(device_token)
|
|
.bind(device_token_hash)
|
|
.execute(&mut **tx)
|
|
.await?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn exchange_google_auth_code(
|
|
http: &Client,
|
|
google: &GoogleOAuthConfig,
|
|
code: &str,
|
|
code_verifier: &str,
|
|
) -> AnyResult<GoogleTokenResponse> {
|
|
http.post(&google.token_uri)
|
|
.form(&[
|
|
("client_id", google.client_id.clone()),
|
|
("client_secret", google.client_secret.clone()),
|
|
("code", code.to_string()),
|
|
("code_verifier", code_verifier.to_string()),
|
|
("grant_type", "authorization_code".to_string()),
|
|
("redirect_uri", google.redirect_uri.clone()),
|
|
])
|
|
.send()
|
|
.await
|
|
.context("failed to exchange google auth code")?
|
|
.error_for_status()
|
|
.context("google token exchange failed")?
|
|
.json::<GoogleTokenResponse>()
|
|
.await
|
|
.context("failed to decode google token response")
|
|
}
|
|
|
|
/// Fetches the signed-in Google account's profile from the OpenID Connect userinfo endpoint.
///
/// # Errors
/// Fails on transport errors, a non-success status, or an undecodable body.
async fn fetch_google_userinfo(http: &Client, access_token: &str) -> AnyResult<GoogleUserInfo> {
    const USERINFO_URL: &str = "https://openidconnect.googleapis.com/v1/userinfo";

    let response = http
        .get(USERINFO_URL)
        .bearer_auth(access_token)
        .send()
        .await
        .context("failed to request google userinfo")?;
    let response = response
        .error_for_status()
        .context("google userinfo request failed")?;
    response
        .json::<GoogleUserInfo>()
        .await
        .context("failed to decode google userinfo")
}
|
|
|
|
async fn require_auth(
|
|
pool: &PgPool,
|
|
headers: &HeaderMap,
|
|
) -> std::result::Result<(UserRow, DeviceRow), (StatusCode, Json<serde_json::Value>)> {
|
|
let auth = headers
|
|
.get(header::AUTHORIZATION)
|
|
.and_then(|v| v.to_str().ok())
|
|
.and_then(|raw| raw.strip_prefix("Bearer "))
|
|
.map(str::trim)
|
|
.filter(|value| !value.is_empty())
|
|
.ok_or_else(|| unauthorized("missing bearer token"))?;
|
|
let token_hash = hash_device_login_token(auth);
|
|
|
|
let row = sqlx::query_as::<_, DeviceRow>(
|
|
r#"
|
|
SELECT
|
|
d.id,
|
|
d.display_name,
|
|
d.platform,
|
|
d.client_version,
|
|
d.last_seen_at,
|
|
NULL::text AS last_ip
|
|
FROM device_login_tokens t
|
|
JOIN devices d ON d.id = t.device_id
|
|
WHERE t.token_hash = $1
|
|
"#,
|
|
)
|
|
.bind(&token_hash)
|
|
.fetch_optional(pool)
|
|
.await
|
|
.map_err(internal_error)?
|
|
.ok_or_else(|| unauthorized("invalid device token"))?;
|
|
|
|
sqlx::query("UPDATE device_login_tokens SET last_seen_at = NOW() WHERE token_hash = $1")
|
|
.bind(&token_hash)
|
|
.execute(pool)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
sqlx::query("UPDATE devices SET last_seen_at = NOW() WHERE id = $1")
|
|
.bind(row.id)
|
|
.execute(pool)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
|
|
let user = sqlx::query_as::<_, UserRow>(
|
|
r#"
|
|
SELECT u.id, u.email, u.name
|
|
FROM users u
|
|
JOIN devices d ON d.user_id = u.id
|
|
WHERE d.id = $1
|
|
"#,
|
|
)
|
|
.bind(row.id)
|
|
.fetch_one(pool)
|
|
.await
|
|
.map_err(internal_error)?;
|
|
|
|
Ok((user, row))
|
|
}
|
|
|
|
async fn ensure_demo_user(pool: &PgPool) -> AnyResult<(Uuid, Uuid)> {
|
|
let existing =
|
|
sqlx::query_as::<_, UserRow>("SELECT id, email, name FROM users WHERE email = $1 LIMIT 1")
|
|
.bind("voson.wang.s@gmail.com")
|
|
.fetch_optional(pool)
|
|
.await?;
|
|
|
|
let user_id = if let Some(user) = existing {
|
|
user.id
|
|
} else {
|
|
sqlx::query_scalar::<_, Uuid>(
|
|
r#"
|
|
INSERT INTO users (email, name)
|
|
VALUES ($1, $2)
|
|
RETURNING id
|
|
"#,
|
|
)
|
|
.bind("voson.wang.s@gmail.com")
|
|
.bind("Voson")
|
|
.fetch_one(pool)
|
|
.await?
|
|
};
|
|
|
|
let existing_device = sqlx::query_scalar::<_, Uuid>(
|
|
"SELECT id FROM devices WHERE user_id = $1 AND display_name = $2 LIMIT 1",
|
|
)
|
|
.bind(user_id)
|
|
.bind("Voson 的 Mac mini")
|
|
.fetch_optional(pool)
|
|
.await?;
|
|
|
|
let device_id = if let Some(id) = existing_device {
|
|
id
|
|
} else {
|
|
sqlx::query_scalar::<_, Uuid>(
|
|
r#"
|
|
INSERT INTO devices (user_id, display_name, platform, client_version, device_fingerprint)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
RETURNING id
|
|
"#,
|
|
)
|
|
.bind(user_id)
|
|
.bind("Voson 的 Mac mini")
|
|
.bind("macOS")
|
|
.bind(env!("CARGO_PKG_VERSION"))
|
|
.bind(new_device_fingerprint())
|
|
.fetch_one(pool)
|
|
.await?
|
|
};
|
|
|
|
Ok((user_id, device_id))
|
|
}
|
|
|
|
async fn upsert_user_from_google(pool: &PgPool, google_user: &GoogleUserInfo) -> AnyResult<Uuid> {
|
|
let existing = sqlx::query_scalar::<_, Uuid>("SELECT id FROM users WHERE email = $1 LIMIT 1")
|
|
.bind(&google_user.email)
|
|
.fetch_optional(pool)
|
|
.await?;
|
|
|
|
if let Some(user_id) = existing {
|
|
sqlx::query("UPDATE users SET name = $1, updated_at = NOW() WHERE id = $2")
|
|
.bind(
|
|
google_user
|
|
.name
|
|
.clone()
|
|
.unwrap_or_else(|| google_user.email.clone()),
|
|
)
|
|
.bind(user_id)
|
|
.execute(pool)
|
|
.await?;
|
|
return Ok(user_id);
|
|
}
|
|
|
|
sqlx::query_scalar::<_, Uuid>(
|
|
r#"
|
|
INSERT INTO users (email, name)
|
|
VALUES ($1, $2)
|
|
RETURNING id
|
|
"#,
|
|
)
|
|
.bind(&google_user.email)
|
|
.bind(
|
|
google_user
|
|
.name
|
|
.clone()
|
|
.unwrap_or_else(|| google_user.email.clone()),
|
|
)
|
|
.fetch_one(pool)
|
|
.await
|
|
.context("failed to create user from google login")
|
|
}
|
|
|
|
async fn upsert_google_oauth_account(
|
|
pool: &PgPool,
|
|
user_id: Uuid,
|
|
google_user: &GoogleUserInfo,
|
|
) -> AnyResult<()> {
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO oauth_accounts (user_id, provider, provider_id, email, name)
|
|
VALUES ($1, 'google', $2, $3, $4)
|
|
ON CONFLICT (provider, provider_id)
|
|
DO UPDATE SET
|
|
user_id = EXCLUDED.user_id,
|
|
email = EXCLUDED.email,
|
|
name = EXCLUDED.name
|
|
"#,
|
|
)
|
|
.bind(user_id)
|
|
.bind(&google_user.sub)
|
|
.bind(&google_user.email)
|
|
.bind(
|
|
google_user
|
|
.name
|
|
.clone()
|
|
.unwrap_or_else(|| google_user.email.clone()),
|
|
)
|
|
.execute(pool)
|
|
.await
|
|
.context("failed to upsert google oauth account")?;
|
|
Ok(())
|
|
}
|
|
|
|
async fn upsert_device_for_login(
|
|
pool: &PgPool,
|
|
user_id: Uuid,
|
|
device_name: &str,
|
|
platform: &str,
|
|
client_version: &str,
|
|
device_fingerprint: &str,
|
|
) -> AnyResult<Uuid> {
|
|
let existing = sqlx::query_scalar::<_, Uuid>(
|
|
"SELECT id FROM devices WHERE user_id = $1 AND device_fingerprint = $2 LIMIT 1",
|
|
)
|
|
.bind(user_id)
|
|
.bind(device_fingerprint)
|
|
.fetch_optional(pool)
|
|
.await?;
|
|
|
|
if let Some(device_id) = existing {
|
|
sqlx::query(
|
|
r#"
|
|
UPDATE devices
|
|
SET display_name = $1, platform = $2, client_version = $3, last_seen_at = NOW()
|
|
WHERE id = $4
|
|
"#,
|
|
)
|
|
.bind(device_name)
|
|
.bind(platform)
|
|
.bind(client_version)
|
|
.bind(device_id)
|
|
.execute(pool)
|
|
.await?;
|
|
return Ok(device_id);
|
|
}
|
|
|
|
sqlx::query_scalar::<_, Uuid>(
|
|
r#"
|
|
INSERT INTO devices (user_id, display_name, platform, client_version, device_fingerprint)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
RETURNING id
|
|
"#,
|
|
)
|
|
.bind(user_id)
|
|
.bind(device_name)
|
|
.bind(platform)
|
|
.bind(client_version)
|
|
.bind(device_fingerprint)
|
|
.fetch_one(pool)
|
|
.await
|
|
.context("failed to create device")
|
|
}
|
|
|
|
async fn issue_device_login_token(
|
|
pool: &PgPool,
|
|
user_id: Uuid,
|
|
device_id: Uuid,
|
|
device_name: &str,
|
|
platform: &str,
|
|
client_version: &str,
|
|
) -> AnyResult<String> {
|
|
let device_token = new_device_login_token();
|
|
let token_hash = hash_device_login_token(&device_token);
|
|
|
|
sqlx::query("DELETE FROM device_login_tokens WHERE device_id = $1")
|
|
.bind(device_id)
|
|
.execute(pool)
|
|
.await?;
|
|
sqlx::query("INSERT INTO device_login_tokens (device_id, token_hash) VALUES ($1, $2)")
|
|
.bind(device_id)
|
|
.bind(token_hash)
|
|
.execute(pool)
|
|
.await?;
|
|
sqlx::query(
|
|
r#"
|
|
INSERT INTO auth_events (
|
|
user_id, device_id, device_name, platform, client_version, ip_addr, forwarded_ip, login_method, login_result
|
|
)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, 'google_desktop', 'success')
|
|
"#,
|
|
)
|
|
.bind(user_id)
|
|
.bind(device_id)
|
|
.bind(device_name)
|
|
.bind(platform)
|
|
.bind(client_version)
|
|
.bind::<Option<String>>(None)
|
|
.bind::<Option<String>>(None)
|
|
.execute(pool)
|
|
.await?;
|
|
|
|
Ok(device_token)
|
|
}
|
|
|
|
async fn consume_device_token_for_poll(
|
|
tx: &mut Transaction<'_, Postgres>,
|
|
session_id: &str,
|
|
) -> AnyResult<String> {
|
|
let token = sqlx::query_scalar::<_, Option<String>>(
|
|
"SELECT device_token FROM desktop_login_sessions WHERE session_id = $1 FOR UPDATE",
|
|
)
|
|
.bind(session_id)
|
|
.fetch_one(&mut **tx)
|
|
.await?
|
|
.context("device token already consumed")?;
|
|
|
|
sqlx::query(
|
|
"UPDATE desktop_login_sessions SET device_token = NULL, updated_at = NOW() WHERE session_id = $1",
|
|
)
|
|
.bind(session_id)
|
|
.execute(&mut **tx)
|
|
.await?;
|
|
|
|
Ok(token)
|
|
}
|
|
|
|
fn internal_error<E: std::fmt::Display>(error: E) -> (StatusCode, Json<serde_json::Value>) {
|
|
(
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
Json(json!({ "error": error.to_string() })),
|
|
)
|
|
}
|
|
|
|
fn unauthorized(message: &str) -> (StatusCode, Json<serde_json::Value>) {
|
|
(StatusCode::UNAUTHORIZED, Json(json!({ "error": message })))
|
|
}
|