From 197c5da40cb6bbb02ddbac3f897f37e0fd3ecae2 Mon Sep 17 00:00:00 2001 From: nullchinchilla Date: Thu, 9 May 2024 16:01:07 -0400 Subject: [PATCH] Update auth token validation to return account level and handle subscription check --- binaries/geph5-broker/src/auth.rs | 20 +++++++++++++++----- binaries/geph5-broker/src/rpc_impl.rs | 11 ++++++++--- libraries/geph5-broker-protocol/src/lib.rs | 2 +- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/binaries/geph5-broker/src/auth.rs b/binaries/geph5-broker/src/auth.rs index 9dcf6f7..f55c2d2 100644 --- a/binaries/geph5-broker/src/auth.rs +++ b/binaries/geph5-broker/src/auth.rs @@ -1,9 +1,10 @@ use std::ops::Deref as _; use argon2::{password_hash::Encoding, Argon2, PasswordHash, PasswordVerifier}; -use geph5_broker_protocol::AuthError; +use geph5_broker_protocol::{AccountLevel, AuthError}; use rand::Rng as _; +use tracing::Level; use crate::{database::POSTGRES, log_error}; @@ -51,11 +52,20 @@ pub async fn new_auth_token(user_id: i32) -> anyhow::Result { } } -pub async fn valid_auth_token(token: &str) -> anyhow::Result { - let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM auth_tokens WHERE token = $1") +pub async fn valid_auth_token(token: &str) -> anyhow::Result> { + let user_id: Option<(i32, i64)> = sqlx::query_as("SELECT user_id, (select count(*) from subscriptions where id = user_id) FROM auth_tokens WHERE token = $1") .bind(token) - .fetch_one(POSTGRES.deref()) + .fetch_optional(POSTGRES.deref()) .await?; - Ok(row.0 > 0) + if let Some((user_id, is_plus)) = user_id { + tracing::debug!(user_id, is_plus, "valid auth token"); + if is_plus == 0 { + Ok(Some(AccountLevel::Free)) + } else { + Ok(Some(AccountLevel::Plus)) + } + } else { + Ok(None) + } } diff --git a/binaries/geph5-broker/src/rpc_impl.rs b/binaries/geph5-broker/src/rpc_impl.rs index b2f3b83..ddd4c27 100644 --- a/binaries/geph5-broker/src/rpc_impl.rs +++ b/binaries/geph5-broker/src/rpc_impl.rs @@ -65,9 +65,11 @@ impl BrokerProtocol for BrokerImpl { epoch: u16, blind_token: BlindedClientToken, ) -> Result { - match valid_auth_token(&auth_token).await { + let user_level = match valid_auth_token(&auth_token).await { Ok(auth) => { - if !auth { + if let Some(level) = auth { + level + } else { return Err(AuthError::Forbidden); } } @@ -75,8 +77,11 @@ impl BrokerProtocol for BrokerImpl { tracing::warn!(err = debug(err), "database failed"); return Err(AuthError::RateLimited); } - } + }; let start = Instant::now(); + if user_level != level { + return Err(AuthError::WrongLevel); + } let signed = match level { AccountLevel::Free => &FREE_MIZARU_SK, AccountLevel::Plus => &PLUS_MIZARU_SK, diff --git a/libraries/geph5-broker-protocol/src/lib.rs b/libraries/geph5-broker-protocol/src/lib.rs index d5bb36e..5555826 100644 --- a/libraries/geph5-broker-protocol/src/lib.rs +++ b/libraries/geph5-broker-protocol/src/lib.rs @@ -47,7 +47,7 @@ pub trait BrokerProtocol { async fn incr_stat(&self, stat: String, value: i32); } -#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq)] pub enum AccountLevel { Free, Plus,