Skip to content

Commit

Permalink
feat: separate view-tokens from push-tokens
Browse files Browse the repository at this point in the history
ssaavedra committed Aug 27, 2024
1 parent 4d785ea commit 574f66b
Showing 9 changed files with 227 additions and 53 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions migrations/0006_view_only_tokens.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Add down migration script here

DROP INDEX IF EXISTS idx_view_tokens_token;
DROP TABLE view_tokens;
19 changes: 19 additions & 0 deletions migrations/0006_view_only_tokens.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
-- Add up migration script here

CREATE TABLE view_tokens (
id SERIAL PRIMARY KEY,
token TEXT NOT NULL,
user_id INT NOT NULL,
view_token_valid_until TIMESTAMP NULL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
last_accessed_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,

FOREIGN KEY (user_id) REFERENCES users (id)
);

CREATE INDEX IF NOT EXISTS idx_view_tokens_token ON view_tokens (token);

-- Insert all existing tokens also as view tokens
INSERT INTO view_tokens (token, user_id, view_token_valid_until)
SELECT token, user_id, datetime('now', '+60 years') as view_token_valid_until
FROM tokens;
15 changes: 11 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@ use rocket::serde::{json::Json, Deserialize};
use rocket::{catchers, fairing, get, launch, post, routes};
use rocket_db_pools::{sqlx, Connection, Database};
use rocket_governor::{rocket_governor_catcher, RocketGovernable, RocketGovernor};
use token::{Token, ValidDbToken};
use token::{Token, ValidDbToken, ValidViewToken};

mod alive_check;
mod car;
@@ -148,6 +148,13 @@ async fn post_token(
format!("OK")
}

#[get("/log/<_>/check")]
async fn check_token_valid(
token: &ValidDbToken,
) -> String {
format!("Token {} is valid", token.simplified())
}

/// Route GET /log/:token/html will return the data in HTML format
#[get("/log/<_>/html?<page>&<count>&<start>&<end>&<interval>&<tz>", rank = 1)]
async fn list_table_html(
@@ -157,7 +164,7 @@ async fn list_table_html(
end: HtmlInputParseableDateTime,
interval: Option<i32>,
tz: form::Tz,
token: &ValidDbToken,
token: &ValidViewToken,
mut db: Connection<Logs>,
_ratelimit: RocketGovernor<'_, RateLimitGuard>,
) -> (ContentType, String) {
@@ -249,7 +256,7 @@ async fn list_table_json(
end: HtmlInputParseableDateTime,
interval: Option<i32>,
tz: form::Tz,
token: &ValidDbToken,
token: &ValidViewToken,
mut db: Connection<Logs>,
_ratelimit: RocketGovernor<'_, RateLimitGuard>,
) -> rocket::response::content::RawJson<String> {
@@ -291,7 +298,7 @@ async fn list_table_svg(
end: HtmlInputParseableDateTime,
interval: Option<i32>,
tz: form::Tz,
token: &ValidDbToken,
token: &ValidViewToken,
mut db: Connection<Logs>,
_ratelimit: RocketGovernor<'_, RateLimitGuard>,
) -> (ContentType, String) {
20 changes: 12 additions & 8 deletions src/print_table.rs
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@ use serde::Serialize;

use crate::{
form::HtmlInputParseableDateTime,
token::{DbToken, Token, ValidDbToken},
token::{DbToken, Token, ValidViewToken},
};

pub struct Pagination {
@@ -142,7 +142,7 @@ impl RowInfo {
/// rows to be fetched.
pub async fn get_paginated_rows_for_token(
db: &mut Connection<crate::Logs>,
token: &ValidDbToken,
token: &ValidViewToken,
pagination: &PaginationResult,
tz: &chrono_tz::Tz,
) -> (Vec<RowInfo>, bool) {
@@ -162,13 +162,15 @@ pub async fn get_paginated_rows_for_token(
let end = end.format("%Y-%m-%d %H:%M:%S").to_string();

let db_rows = sqlx::query!(
"SELECT amps, volts, watts, created_at, user_agent, client_ip, energy_log.token as token, u.location as location
"SELECT amps, volts, watts, energy_log.created_at as created_at, user_agent, client_ip, energy_log.token as token, u.location as location
FROM energy_log
INNER JOIN tokens t
ON t.token = energy_log.token
INNER JOIN users u
ON u.id = t.user_id
WHERE energy_log.token = ?
INNER JOIN view_tokens vt
ON vt.user_id = u.id
WHERE vt.token = ?
AND energy_log.created_at BETWEEN ? AND ?
ORDER BY created_at DESC
LIMIT ?
@@ -218,7 +220,7 @@ pub async fn get_paginated_rows_for_token(
/// interval passed as a parameter.
pub async fn get_avg_max_rows_for_token<Tz: chrono::TimeZone>(
db: &mut Connection<crate::Logs>,
token: &ValidDbToken,
token: &ValidViewToken,
start: &DateTime<Tz>,
end: &DateTime<Tz>,
interval: i32,
@@ -229,14 +231,16 @@ pub async fn get_avg_max_rows_for_token<Tz: chrono::TimeZone>(
let end = end.naive_utc();

let db_rows = sqlx::query!(
"SELECT AVG(amps) as amps, MAX(amps) as max_amps, AVG(volts) as volts, AVG(watts) as watts, MAX(watts) as max_watts, created_at, user_agent, client_ip, energy_log.token as token, u.location as location
"SELECT AVG(amps) as amps, MAX(amps) as max_amps, AVG(volts) as volts, AVG(watts) as watts, MAX(watts) as max_watts, energy_log.created_at as created_at, user_agent, client_ip, energy_log.token as token, u.location as location
FROM energy_log
INNER JOIN tokens t
ON t.token = energy_log.token
INNER JOIN users u
ON u.id = t.user_id
WHERE energy_log.token = ? AND created_at BETWEEN ? AND ?
GROUP BY strftime('%s', created_at) / ?
INNER JOIN view_tokens vt
ON vt.user_id = u.id
WHERE vt.token = ? AND energy_log.created_at BETWEEN ? AND ?
GROUP BY strftime('%s', energy_log.created_at) / ?
ORDER BY created_at DESC",
token,
start,
182 changes: 145 additions & 37 deletions src/token.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

use rocket_db_pools::Connection;
use sqlx::{Encode, Type};

@@ -20,9 +19,8 @@ impl Token for DbToken {
}
}


/// This struct is used to store the token that is passed in the URL.
///
///
/// The second argument is a private unit struct, which is used to statically
/// ensure that the token can only be created by its `FromRequest`
/// implementation.
@@ -34,7 +32,10 @@ impl Token for ValidDbToken {
}
}

impl<DB: sqlx::Database> Type<DB> for DbToken where std::string::String: Type<DB> {
impl<DB: sqlx::Database> Type<DB> for DbToken
where
std::string::String: Type<DB>,
{
fn type_info() -> <DB as sqlx::Database>::TypeInfo {
<String as Type<DB>>::type_info()
}
@@ -43,7 +44,10 @@ impl<DB: sqlx::Database> Type<DB> for DbToken where std::string::String: Type<DB
}
}

impl<DB: sqlx::Database> Type<DB> for ValidDbToken where std::string::String: Type<DB> {
impl<DB: sqlx::Database> Type<DB> for ValidDbToken
where
std::string::String: Type<DB>,
{
fn type_info() -> <DB as sqlx::Database>::TypeInfo {
<String as Type<DB>>::type_info()
}
@@ -52,20 +56,30 @@ impl<DB: sqlx::Database> Type<DB> for ValidDbToken where std::string::String: Ty
}
}

impl<'a, DB: sqlx::Database> Encode<'a, DB> for DbToken where std::string::String: Encode<'a, DB> {
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<'a>>::ArgumentBuffer) -> sqlx::encode::IsNull {
impl<'a, DB: sqlx::Database> Encode<'a, DB> for DbToken
where
std::string::String: Encode<'a, DB>,
{
fn encode_by_ref(
&self,
buf: &mut <DB as sqlx::database::HasArguments<'a>>::ArgumentBuffer,
) -> sqlx::encode::IsNull {
self.0.encode_by_ref(buf)
}
}

impl<'a, DB: sqlx::Database> Encode<'a, DB> for ValidDbToken where std::string::String: Encode<'a, DB> {
fn encode_by_ref(&self, buf: &mut <DB as sqlx::database::HasArguments<'a>>::ArgumentBuffer) -> sqlx::encode::IsNull {
impl<'a, DB: sqlx::Database> Encode<'a, DB> for ValidDbToken
where
std::string::String: Encode<'a, DB>,
{
fn encode_by_ref(
&self,
buf: &mut <DB as sqlx::database::HasArguments<'a>>::ArgumentBuffer,
) -> sqlx::encode::IsNull {
self.0.encode_by_ref(buf)
}
}



impl std::fmt::Display for DbToken {
/// User-facing display of the token, showing only the first and last 4
/// characters.
@@ -80,12 +94,51 @@ impl std::fmt::Display for ValidDbToken {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}

}

enum RequestTokenDbResult {
Ok(ValidDbToken),
NotFound,
/// This struct is used to store a view-token passed in the URL.
///
/// The second argument is a private unit struct, which is used to statically
/// ensure that the token can only be created by its `FromRequest`
/// implementation.
pub struct ValidViewToken(pub DbToken, ());

impl Token for ValidViewToken {
fn full_token<'a>(&'a self) -> &'a str {
self.0.full_token()
}
}

impl<DB: sqlx::Database> Type<DB> for ValidViewToken
where
std::string::String: Type<DB>,
{
fn type_info() -> <DB as sqlx::Database>::TypeInfo {
<String as Type<DB>>::type_info()
}
fn compatible(ty: &<DB as sqlx::Database>::TypeInfo) -> bool {
<String as Type<DB>>::compatible(ty)
}
}

impl<'a, DB: sqlx::Database> Encode<'a, DB> for ValidViewToken
where
std::string::String: Encode<'a, DB>,
{
fn encode_by_ref(
&self,
buf: &mut <DB as sqlx::database::HasArguments<'a>>::ArgumentBuffer,
) -> sqlx::encode::IsNull {
self.0.encode_by_ref(buf)
}
}

impl std::fmt::Display for ValidViewToken {
/// User-facing display of the token, showing only the first and last 4
/// characters.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "<{}>", self.0.simplified())
}
}

/// This function returns a cleaned up version of the token, showing only the
@@ -105,32 +158,87 @@ impl<'r> rocket::request::FromRequest<'r> for &'r ValidDbToken {
async fn from_request(
request: &'r rocket::Request<'_>,
) -> rocket::request::Outcome<Self, Self::Error> {
let result = request.local_cache_async(async {
let mut db = request.guard::<Connection<crate::Logs>>().await.expect("Failed to get db connection");
let token = request.routed_segment(1).map(|s| s.to_string());
match token {
Some(token) => {
let rows = sqlx::query!(
"SELECT COUNT(*) as count FROM tokens WHERE token = ?",
token
);
let count = rows.fetch_one(&mut **db).await.unwrap().count;
log::info!("Token count in DB: {}", count);
if count == 0 {
return RequestTokenDbResult::NotFound;
let result = request
.local_cache_async(async {
let mut db = request
.guard::<Connection<crate::Logs>>()
.await
.expect("Failed to get db connection");
let token = request.routed_segment(1).map(|s| s.to_string());
match token {
Some(token) => {
let rows = sqlx::query!(
"SELECT COUNT(*) as count FROM tokens WHERE token = ?",
token
);
let count = rows.fetch_one(&mut **db).await.unwrap().count;
log::info!("Token count in DB: {}", count);
if count == 0 {
return None;
}
Some(ValidDbToken(DbToken(token), ()))
}
_ => {
log::info!("No token found");
None
}
RequestTokenDbResult::Ok(ValidDbToken(DbToken(token), ()))
}
_ => {
log::info!("No token found");
RequestTokenDbResult::NotFound
})
.await;

match result {
Some(token) => rocket::request::Outcome::Success(token),
None => rocket::request::Outcome::Forward(rocket::http::Status::NotFound),
}
}
}


#[rocket::async_trait]
impl<'r> rocket::request::FromRequest<'r> for &'r ValidViewToken {
type Error = ();

async fn from_request(
request: &'r rocket::Request<'_>,
) -> rocket::request::Outcome<Self, Self::Error> {
let result = request
.local_cache_async(async {
let mut db = request
.guard::<Connection<crate::Logs>>()
.await
.expect("Failed to get db connection");
let token = request.routed_segment(1).map(|s| s.to_string());
match token {
Some(token) => {
let rows = sqlx::query!(
"SELECT COUNT(*) as count FROM view_tokens WHERE token = ? AND (view_token_valid_until is null OR view_token_valid_until > datetime(\"NOW\"))",
token
);
let count = rows.fetch_one(&mut **db).await.unwrap().count;
log::info!("Token count in DB: {}", count);
if count == 0 {
return None;
}
let now = chrono::Utc::now().format("%Y-%m-%d %H:%M:%S").to_string();
// Update last accessed time
sqlx::query!(
"UPDATE view_tokens SET last_accessed_at = ? WHERE token = ?",
now,
token
).execute(&mut **db).await.unwrap();
Some(ValidViewToken(DbToken(token), ()))
}
_ => {
log::info!("No token found");
None
}
}
}
}).await;
})
.await;

match result {
RequestTokenDbResult::Ok(token) => rocket::request::Outcome::Success(token),
RequestTokenDbResult::NotFound => rocket::request::Outcome::Forward(rocket::http::Status::NotFound),
Some(token) => rocket::request::Outcome::Success(token),
None => rocket::request::Outcome::Forward(rocket::http::Status::NotFound),
}
}
}
}

0 comments on commit 574f66b

Please sign in to comment.