Skip to content

Commit

Permalink
Merge pull request #497 from sebadob/code-query-cleanup
Browse files Browse the repository at this point in the history
code and query cleanup
  • Loading branch information
sebadob authored Jun 25, 2024
2 parents c30285f + c2b9403 commit 871a705
Show file tree
Hide file tree
Showing 13 changed files with 84 additions and 112 deletions.
34 changes: 0 additions & 34 deletions src/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,6 @@ const B64_URL_SAFE: engine::GeneralPurpose = general_purpose::URL_SAFE;
const B64_URL_SAFE_NO_PAD: engine::GeneralPurpose = general_purpose::URL_SAFE_NO_PAD;
const B64_STD: engine::GeneralPurpose = general_purpose::STANDARD;

// Returns the cache key for a given client
pub fn cache_entry_client(id: &str) -> String {
format!("client_{}", id)
}

// Converts a given Json array / list into a Vec<String>
pub fn json_arr_to_vec(arr: &str) -> Vec<String> {
arr.chars()
.skip(1)
.filter(|&c| c != '"')
// TODO improve -> array inside array would not work
.take_while(|&c| c != ']')
.collect::<String>()
.split(',')
.map(|i| i.to_string())
.collect()
}

pub fn get_local_hostname() -> String {
let hostname_os = gethostname();
hostname_os
Expand Down Expand Up @@ -209,22 +191,6 @@ fn ip_from_cust_header(headers: &HeaderMap) -> Option<IpAddr> {
mod tests {
use super::*;
use pretty_assertions::assert_eq;
use std::string::String;

#[test]
fn test_json_arr_to_vec() {
let arr = String::from("[\"one\",\"two\",\"three\"]");
let arr_as_vec = vec!["one", "two", "three"];
assert_eq!(json_arr_to_vec(&arr), arr_as_vec);

let arr = String::from("[\"one\"]");
let arr_as_vec = vec!["one"];
assert_eq!(json_arr_to_vec(&arr), arr_as_vec);

let arr = String::from("[]");
let arr_as_vec = vec![""];
assert_eq!(json_arr_to_vec(&arr), arr_as_vec);
}

#[test]
fn test_get_rand() {
Expand Down
4 changes: 2 additions & 2 deletions src/models/src/entity/clients.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use rauthy_common::constants::{
EPHEMERAL_CLIENTS_ALLOWED_FLOWS, EPHEMERAL_CLIENTS_ALLOWED_SCOPES, EPHEMERAL_CLIENTS_FORCE_MFA,
IDX_CLIENTS, PROXY_MODE, RAUTHY_VERSION,
};
use rauthy_common::utils::{cache_entry_client, get_rand, real_ip_from_req};
use rauthy_common::utils::{get_rand, real_ip_from_req};
use rauthy_error::{ErrorResponse, ErrorResponseType};
use redhac::{
cache_get, cache_get_from, cache_get_value, cache_insert, cache_put, cache_remove, AckLevel,
Expand Down Expand Up @@ -225,7 +225,7 @@ impl Client {
let client = cache_get!(
Client,
CACHE_NAME_12HR.to_string(),
cache_entry_client(&id),
Client::get_cache_entry(&id),
&data.caches.ha_cache_config,
false
)
Expand Down
10 changes: 5 additions & 5 deletions src/models/src/entity/colors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub struct ColorEntity {
// CRUD
impl ColorEntity {
pub async fn delete(data: &web::Data<AppState>, client_id: &str) -> Result<(), ErrorResponse> {
sqlx::query!("delete from colors where client_id = $1", client_id,)
sqlx::query!("DELETE FROM colors WHERE client_id = $1", client_id,)
.execute(&data.db)
.await?;

Expand Down Expand Up @@ -48,7 +48,7 @@ impl ColorEntity {
return Ok(colors);
}

let res = sqlx::query_as!(Self, "select * from colors where client_id = $1", client_id)
let res = sqlx::query_as!(Self, "SELECT * FROM colors WHERE client_id = $1", client_id)
.fetch_optional(&data.db)
.await?;
let colors = match res {
Expand Down Expand Up @@ -81,14 +81,14 @@ impl ColorEntity {

#[cfg(not(feature = "postgres"))]
let q = sqlx::query!(
"insert or replace into colors (client_id, data) values ($1, $2)",
"INSERT OR REPLACE INTO colors (client_id, data) values ($1, $2)",
client_id,
col_bytes,
);
#[cfg(feature = "postgres")]
let q = sqlx::query!(
r#"insert into colors (client_id, data) values ($1, $2)
on conflict(client_id) do update set data = $2"#,
r#"INSERT INTO colors (client_id, data) values ($1, $2)
ON CONFLICT(client_id) DO UPDATE SET data = $2"#,
client_id,
col_bytes,
);
Expand Down
19 changes: 10 additions & 9 deletions src/models/src/entity/db_version.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub struct DbVersion {

impl DbVersion {
pub async fn find(db: &DbPool) -> Option<Self> {
let res = query!("select data from config where id = 'db_version'")
let res = query!("SELECT data FROM config WHERE id = 'db_version'")
.fetch_optional(db)
.await
.ok()?;
Expand All @@ -39,13 +39,14 @@ impl DbVersion {

#[cfg(not(feature = "postgres"))]
let q = query!(
"insert or replace into config (id, data) values ('db_version', $1)",
"INSERT OR REPLACE INTO config (id, data) VALUES ('db_version', $1)",
data,
);
#[cfg(feature = "postgres")]
let q = query!(
r#"insert into config (id, data) values ('db_version', $1)
on conflict(id) do update set data = $1"#,
r#"INSERT INTO config (id, data)
VALUES ('db_version', $1)
ON CONFLICT(id) DO UPDATE SET data = $1"#,
data,
);
q.execute(db).await?;
Expand All @@ -59,7 +60,7 @@ impl DbVersion {
debug!("Current Rauthy Version: {:?}", app_version);

// check DB version for compatibility
let db_exists = query!("select id from config limit 1")
let db_exists = query!("SELECT id FROM config LIMIT 1")
.fetch_one(db)
.await
.is_ok();
Expand Down Expand Up @@ -133,13 +134,13 @@ impl DbVersion {

// the passkeys table was introduced with v0.15.0
#[cfg(feature = "postgres")]
let is_db_v0_15_0 = query!("select * from pg_tables where tablename = 'passkeys' limit 1")
let is_db_v0_15_0 = query!("SELECT * FROM pg_tables WHERE tablename = 'passkeys' LIMIT 1")
.fetch_one(db)
.await
.is_err();
#[cfg(not(feature = "postgres"))]
let is_db_v0_15_0 = query!(
"select * from sqlite_master where type = 'table' and name = 'passkeys' limit 1"
"SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'passkeys' LIMIT 1"
)
.fetch_one(db)
.await
Expand All @@ -155,13 +156,13 @@ impl DbVersion {
// which is there since the very beginning.
#[cfg(feature = "postgres")]
let is_db_pre_v0_15_0 =
query!("select * from pg_tables where tablename = 'clients' limit 1")
query!("SELECT * FROM pg_tables WHERE tablename = 'clients' LIMIT 1")
.fetch_one(db)
.await
.is_err();
#[cfg(not(feature = "postgres"))]
let is_db_pre_v0_15_0 =
query!("select * from sqlite_master where type = 'table' and name = 'clients' limit 1")
query!("SELECT * FROM sqlite_master WHERE type = 'table' AND name = 'clients' LIMIT 1")
.fetch_one(db)
.await
.is_err();
Expand Down
10 changes: 5 additions & 5 deletions src/models/src/entity/groups.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl Group {
};

sqlx::query!(
"insert into groups (id, name) values ($1, $2)",
"INSERT INTO groups (id, name) VALUES ($1, $2)",
new_group.id,
new_group.name,
)
Expand Down Expand Up @@ -93,7 +93,7 @@ impl Group {
user.save(data, None, Some(&mut txn)).await?;
}

sqlx::query!("delete from groups where id = $1", group.id)
sqlx::query!("DELETE FROM groups WHERE id = $1", group.id)
.execute(&mut *txn)
.await?;

Expand All @@ -118,7 +118,7 @@ impl Group {

// Returns a single group by id
pub async fn find(data: &web::Data<AppState>, id: String) -> Result<Self, ErrorResponse> {
let res = sqlx::query_as!(Self, "select * from groups where id = $1", id,)
let res = sqlx::query_as!(Self, "SELECT * FROM groups WHERE id = $1", id,)
.fetch_one(&data.db)
.await?;

Expand All @@ -139,7 +139,7 @@ impl Group {
return Ok(groups);
}

let res = sqlx::query_as!(Self, "select * from groups")
let res = sqlx::query_as!(Self, "SELECT * FROM groups")
.fetch_all(&data.db)
.await?;

Expand Down Expand Up @@ -197,7 +197,7 @@ impl Group {
};

sqlx::query!(
"update groups set name = $1 where id = $2",
"UPDATE groups SET name = $1 WHERE id = $2",
new_group.name,
new_group.id,
)
Expand Down
10 changes: 5 additions & 5 deletions src/models/src/entity/jwk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ impl Jwk {
pub async fn save(&self, db: &DbPool) -> Result<(), ErrorResponse> {
let sig_str = self.signature.as_str();
sqlx::query!(
r#"insert into jwks (kid, created_at, signature, enc_key_id, jwk)
values ($1, $2, $3, $4, $5)"#,
r#"INSERT INTO jwks (kid, created_at, signature, enc_key_id, jwk)
VALUES ($1, $2, $3, $4, $5)"#,
self.kid,
self.created_at,
sig_str,
Expand Down Expand Up @@ -164,7 +164,7 @@ impl JWKS {
return Ok(jwks);
}

let res = sqlx::query_as!(Jwk, "select * from jwks")
let res = sqlx::query_as!(Jwk, "SELECT * FROM jwks")
.fetch_all(&data.db)
.await?;

Expand Down Expand Up @@ -644,7 +644,7 @@ impl JwkKeyPair {
return Ok(jwk_opt);
}

let jwk = sqlx::query_as!(Jwk, "select * from jwks where kid = $1", kid,)
let jwk = sqlx::query_as!(Jwk, "SELECT * FROM jwks WHERE kid = $1", kid,)
.fetch_one(&data.db)
.await?;

Expand Down Expand Up @@ -680,7 +680,7 @@ impl JwkKeyPair {
return Ok(jwk_opt);
}

let mut jwks = sqlx::query_as!(Jwk, "select * from jwks")
let mut jwks = sqlx::query_as!(Jwk, "SELECT * FROM jwks")
.fetch_all(&data.db)
.await?
.into_iter()
Expand Down
12 changes: 6 additions & 6 deletions src/models/src/entity/magic_links.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ impl MagicLink {
};

sqlx::query!(
r#"insert into magic_links (id, user_id, csrf_token, exp, used, usage)
values ($1, $2, $3, $4, $5, $6)"#,
r#"INSERT INTO magic_links (id, user_id, csrf_token, exp, used, usage)
VALUES ($1, $2, $3, $4, $5, $6)"#,
link.id,
link.user_id,
link.csrf_token,
Expand All @@ -131,7 +131,7 @@ impl MagicLink {
}

pub async fn find(data: &web::Data<AppState>, id: &str) -> Result<Self, ErrorResponse> {
let res = sqlx::query_as!(Self, "select * from magic_links where id = $1", id)
let res = sqlx::query_as!(Self, "SELECT * FROM magic_links WHERE id = $1", id)
.fetch_one(&data.db)
.await?;

Expand All @@ -144,7 +144,7 @@ impl MagicLink {
) -> Result<MagicLink, ErrorResponse> {
let res = sqlx::query_as!(
Self,
"select * from magic_links where user_id = $1",
"SELECT * FROM magic_links WHERE user_id = $1",
user_id
)
.fetch_one(&data.db)
Expand All @@ -158,7 +158,7 @@ impl MagicLink {
user_id: &str,
) -> Result<(), ErrorResponse> {
sqlx::query!(
"delete from magic_links where user_id = $1 and usage like 'email_change$%'",
"DELETE FROM magic_links WHERE user_id = $1 AND USAGE LIKE 'email_change$%'",
user_id,
)
.execute(&data.db)
Expand All @@ -169,7 +169,7 @@ impl MagicLink {

pub async fn save(&self, data: &web::Data<AppState>) -> Result<(), ErrorResponse> {
sqlx::query!(
"update magic_links set cookie = $1, exp = $2, used = $3 where id = $4",
"UPDATE magic_links SET cookie = $1, exp = $2, used = $3 WHERE id = $4",
self.cookie,
self.exp,
self.used,
Expand Down
10 changes: 5 additions & 5 deletions src/models/src/entity/password.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ impl PasswordPolicy {
return Ok(policy);
}

let res = sqlx::query("select data from config where id = 'password_policy'")
let res = sqlx::query("SELECT data FROM config WHERE id = 'password_policy'")
.fetch_one(&data.db)
.await?;
let bytes: Vec<u8> = res.get("data");
Expand All @@ -158,7 +158,7 @@ impl PasswordPolicy {
let slf = bincode::serialize(&self).unwrap();

sqlx::query!(
"update config set data = $1 where id = 'password_policy'",
"UPDATE config SET data = $1 WHERE id = 'password_policy'",
slf
)
.execute(&data.db)
Expand Down Expand Up @@ -219,7 +219,7 @@ impl RecentPasswordsEntity {
passwords: &String,
) -> Result<(), ErrorResponse> {
sqlx::query!(
"insert into recent_passwords (user_id, passwords) values ($1, $2)",
"INSERT INTO recent_passwords (user_id, passwords) VALUES ($1, $2)",
user_id,
passwords,
)
Expand All @@ -232,7 +232,7 @@ impl RecentPasswordsEntity {
pub async fn find(data: &web::Data<AppState>, user_id: &str) -> Result<Self, ErrorResponse> {
let res = sqlx::query_as!(
Self,
"select * from recent_passwords where user_id = $1",
"SELECT * FROM recent_passwords WHERE user_id = $1",
user_id,
)
.fetch_one(&data.db)
Expand All @@ -242,7 +242,7 @@ impl RecentPasswordsEntity {

pub async fn save(&self, data: &web::Data<AppState>) -> Result<(), ErrorResponse> {
sqlx::query!(
"update recent_passwords set passwords = $1 where user_id = $2",
"UPDATE recent_passwords SET passwords = $1 WHERE user_id = $2",
self.passwords,
self.user_id,
)
Expand Down
10 changes: 5 additions & 5 deletions src/models/src/entity/roles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Role {
name: role_req.role,
};
sqlx::query!(
"insert into roles (id, name) values ($1, $2)",
"INSERT INTO roles (id, name) VALUES ($1, $2)",
new_role.id,
new_role.name,
)
Expand Down Expand Up @@ -100,7 +100,7 @@ impl Role {
user.save(data, None, Some(&mut txn)).await?;
}

sqlx::query!("delete from roles where id = $1", id)
sqlx::query!("DELETE FROM roles WHERE id = $1", id)
.execute(&mut *txn)
.await?;

Expand All @@ -126,7 +126,7 @@ impl Role {

// Returns a single role by id
pub async fn find(data: &web::Data<AppState>, id: &str) -> Result<Self, ErrorResponse> {
let res = sqlx::query_as!(Self, "select * from roles where id = $1", id)
let res = sqlx::query_as!(Self, "SELECT * FROM roles WHERE id = $1", id)
.fetch_one(&data.db)
.await?;

Expand All @@ -147,7 +147,7 @@ impl Role {
return Ok(roles);
}

let res = sqlx::query_as!(Self, "select * from roles")
let res = sqlx::query_as!(Self, "SELECT * FROM roles")
.fetch_all(&data.db)
.await?;

Expand Down Expand Up @@ -200,7 +200,7 @@ impl Role {

let new_role = Role { id, name: new_name };
sqlx::query!(
"update roles set name = $1 where id = $2",
"UPDATE roles SET name = $1 WHERE id = $2",
new_role.name,
new_role.id,
)
Expand Down
Loading

0 comments on commit 871a705

Please sign in to comment.