Skip to content

Commit

Permalink
🔀 Merge pull request #8 from Jumpdrive-dev/tests
Browse files Browse the repository at this point in the history
🔖 Version 1.1.0
  • Loading branch information
rster2002 authored May 13, 2023
2 parents f90b8e4 + 1b0f045 commit 6b4a801
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ I currently have no intentions to publish this to crates.io, so for now if you w
dependency using:

```toml
jumpdrive-auth = { git = "https://github.com/Jumpdrive-dev/Auth-Services", tag = "1.0.0" }
jumpdrive-auth = { git = "https://github.com/Jumpdrive-dev/Auth-Services", tag = "1.1.0" }
```

## Features
Expand Down
7 changes: 4 additions & 3 deletions src/models/jwt/jwt_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ use crate::models::jwt::jwt_token_type::JwtTokenType;
/// The header of a JWT token. Used to identify what signing algorithm is used and what type of
/// token it is.
#[derive(Debug, Serialize, Deserialize)]
pub struct JwtHeader {
pub struct JwtHeader<T = JwtTokenType>
{
/// The algorithm of that the server used to sign the JWT token. Possible values can be found in
/// [RFC 7518](https://www.rfc-editor.org/rfc/rfc7518#section-3).
pub alg: String,
Expand All @@ -16,10 +17,10 @@ pub struct JwtHeader {

/// This is usually used when using nested JWT tokens, but here it's used to differentiate
/// between access tokens and refresh tokens.
pub cty: Option<JwtTokenType>,
pub cty: Option<T>,
}

impl Default for JwtHeader {
impl<T> Default for JwtHeader<T> {
fn default() -> Self {
Self {
alg: "RS256".to_string(),
Expand Down
57 changes: 44 additions & 13 deletions src/services/jwt_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,15 @@ impl JwtService {
}

/// Used to create a JWT token with custom a headers and claims.
pub fn create_token<T>(
pub fn create_token<T, H>(
&self,
header: JwtHeader,
header: JwtHeader<H>,
claims: JwtClaims,
payload: T,
) -> Result<String>
where
T: Serialize,
for<'a> H: Serialize + Deserialize<'a>,
{
let payload_value = self.merge_claims_with_payload(claims, payload)?;

Expand Down Expand Up @@ -167,7 +168,7 @@ impl JwtService {
for<'a> T: Deserialize<'a>,
{
let (claims, payload) = self.decode_access_token_unchecked(token)?;
JwtService::guard_claims(&claims)?;
Self::guard_claims(&claims)?;

Ok(payload)
}
Expand All @@ -181,7 +182,7 @@ impl JwtService {
where
for<'a> T: Deserialize<'a>,
{
let (header, claims, payload) = self.decode_jwt(token.into())?;
let (header, claims, payload) = self.decode_jwt::<T, JwtTokenType>(token.into())?;

let Some(token_type) = header.cty else {
return Err(JwtError::NotAnAccessToken);
Expand All @@ -200,7 +201,7 @@ impl JwtService {
for<'a> T: Deserialize<'a>,
{
let (claims, payload) = self.decode_refresh_token_unchecked(token)?;
JwtService::guard_claims(&claims)?;
Self::guard_claims(&claims)?;

Ok(payload)
}
Expand All @@ -214,7 +215,7 @@ impl JwtService {
where
for<'a> T: Deserialize<'a>,
{
let (header, claims, payload) = self.decode_jwt(token.into())?;
let (header, claims, payload) = self.decode_jwt::<T, JwtTokenType>(token.into())?;

let Some(token_type) = header.cty else {
return Err(JwtError::NotARefreshToken);
Expand All @@ -227,14 +228,44 @@ impl JwtService {
Ok((claims, payload))
}

/// Decodes a JWT token and only returns the claims.
pub fn decode_claims<H>(
&self,
token: impl Into<String>
) -> Result<JwtClaims>
where
for<'a> H: Serialize + Deserialize<'a>
{
let claims = self.decode_jwt::<(), H>(token)?.1;
Self::guard_claims(&claims)?;

Ok(claims)
}

/// Decodes a JWT token and only returns the claims. Does not perform any checks other than
/// checking the signature of the token.
pub fn decode_claims_unchecked<H>(
&self,
token: impl Into<String>,
) -> Result<JwtClaims>
where
for<'a> H: Serialize + Deserialize<'a>
{
Ok(self.decode_jwt::<(), H>(token)?.1)
}

/// Decodes the given JWT token and returns all the given important parts of the token. It
/// doesn't perform any checks apart from checking the signature. All checks should done by
/// the caller. If you created a token using either the [create_access_token] or
/// [create_refresh_token] method, make sure to use decode methods for those instead of
/// this one.
pub fn decode_jwt<T>(&self, token: impl Into<String>) -> Result<(JwtHeader, JwtClaims, T)>
pub fn decode_jwt<T, H>(
&self,
token: impl Into<String>,
) -> Result<(JwtHeader<H>, JwtClaims, T)>
where
for<'a> T: Deserialize<'a>,
for<'a> H: Serialize + Deserialize<'a>,
{
let token = token.into();

Expand Down Expand Up @@ -359,7 +390,7 @@ mod tests {
let jwt_service = create_jwt_service();

let token = jwt_service
.create_token(
.create_token::<TestPayload, JwtTokenType>(
JwtHeader::default(),
create_jwt_claims(),
TestPayload {
Expand All @@ -368,7 +399,7 @@ mod tests {
)
.unwrap();

let parts = jwt_service.decode_jwt::<TestPayload>(token).unwrap();
let parts = jwt_service.decode_jwt::<TestPayload, JwtTokenType>(token).unwrap();

assert_eq!(parts.0.typ, "JWT");
assert_eq!(parts.0.alg, "RS256");
Expand All @@ -387,14 +418,14 @@ mod tests {
};

let token = jwt_service
.create_token(
.create_token::<&TestPayload, JwtTokenType>(
JwtHeader::default(),
create_jwt_claims(),
&payload,
)
.unwrap();

let parts = jwt_service.decode_jwt::<TestPayload>(token).unwrap();
let parts = jwt_service.decode_jwt::<TestPayload, JwtTokenType>(token).unwrap();

assert_eq!(parts.0.typ, "JWT");
assert_eq!(parts.0.alg, "RS256");
Expand All @@ -417,7 +448,7 @@ mod tests {
)
.unwrap();

let parts = jwt_service.decode_jwt::<TestPayload>(token).unwrap();
let parts = jwt_service.decode_jwt::<TestPayload, JwtTokenType>(token).unwrap();

assert_eq!(parts.0.typ, "JWT");
assert_eq!(parts.0.alg, "RS256");
Expand All @@ -440,7 +471,7 @@ mod tests {
)
.unwrap();

let parts = jwt_service.decode_jwt::<TestPayload>(token).unwrap();
let parts = jwt_service.decode_jwt::<TestPayload, JwtTokenType>(token).unwrap();

assert_eq!(parts.0.typ, "JWT");
assert_eq!(parts.0.alg, "RS256");
Expand Down
40 changes: 34 additions & 6 deletions src/services/totp_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl TotpService {
) -> Result<bool, TotpError> {
let code = code.into();
let secret_key = secret_key.into();
let current_step = Self::get_time_step();
let current_step = Self::get_current_time_step();

let last = Self::generate_code_with_step(&secret_key, current_step - 1)?;

Expand Down Expand Up @@ -76,11 +76,39 @@ impl TotpService {
Ok(format!("{:0>6}", code))
}

fn get_time_step() -> u64 {
fn get_current_time_step() -> u64 {
(Utc::now().timestamp() / 30) as u64
}
}

#[cfg(test)]
impl TotpService {
/// Only available withing a test environment. Generates a 2FA code with the given secret key
/// and time step. To generate a code using the current time step, use
/// [TotpService::test_generate_current_code].
pub fn test_generate_code_with_step(
secret_key: impl Into<String>,
step: u64,
) -> Result<String, TotpError> {
Self::generate_code_with_step(
secret_key,
step
)
}

/// Only available within a test environment. Generates a 2FA code with the given secret key and
/// uses the current time step. To generate a code with a specific time step, use
/// [TotpService::test_generate_code_with_step].
pub fn test_generate_current_code(
secret_key: impl Into<String>,
) -> Result<String, TotpError> {
Self::generate_code_with_step(
secret_key,
Self::get_current_time_step(),
)
}
}

#[cfg(test)]
mod tests {

Expand Down Expand Up @@ -110,7 +138,7 @@ mod tests {
#[test]
fn correct_code_can_be_validated() {
let code =
TotpService::generate_code_with_step(SECRET_KEY, TotpService::get_time_step()).unwrap();
TotpService::generate_code_with_step(SECRET_KEY, TotpService::get_current_time_step()).unwrap();

let valid = TotpService::validate_code(SECRET_KEY, code).unwrap();

Expand All @@ -120,7 +148,7 @@ mod tests {
#[test]
fn code_from_previous_step_is_still_valid() {
let code =
TotpService::generate_code_with_step(SECRET_KEY, TotpService::get_time_step() - 1)
TotpService::generate_code_with_step(SECRET_KEY, TotpService::get_current_time_step() - 1)
.unwrap();

let valid = TotpService::validate_code(SECRET_KEY, code).unwrap();
Expand All @@ -131,7 +159,7 @@ mod tests {
#[test]
fn expired_code_is_invalid() {
let code =
TotpService::generate_code_with_step(SECRET_KEY, TotpService::get_time_step() - 10)
TotpService::generate_code_with_step(SECRET_KEY, TotpService::get_current_time_step() - 10)
.unwrap();

let valid = TotpService::validate_code(SECRET_KEY, code).unwrap();
Expand All @@ -142,7 +170,7 @@ mod tests {
#[test]
fn future_code_is_invalid() {
let code =
TotpService::generate_code_with_step(SECRET_KEY, TotpService::get_time_step() + 1)
TotpService::generate_code_with_step(SECRET_KEY, TotpService::get_current_time_step() + 1)
.unwrap();

let valid = TotpService::validate_code(SECRET_KEY, code).unwrap();
Expand Down

0 comments on commit 6b4a801

Please sign in to comment.