-
-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1f8338f
commit 60bc7de
Showing
5 changed files
with
284 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
use std::net::SocketAddr; | ||
|
||
use axum::{ | ||
error_handling::HandleErrorLayer, response::IntoResponse, routing::get, BoxError, Router, | ||
}; | ||
use http::StatusCode; | ||
use serde::{Deserialize, Serialize}; | ||
use tower::ServiceBuilder; | ||
use tower_sessions::{sqlx::MySqlPool, time::Duration, MySqlStore, Session, SessionManagerLayer}; | ||
|
||
const COUNTER_KEY: &str = "counter"; | ||
|
||
#[derive(Serialize, Deserialize, Default)] | ||
struct Counter(usize); | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<(), Box<dyn std::error::Error>> { | ||
let database_url = std::option_env!("DATABASE_URL").expect("Missing DATABASE_URL."); | ||
let pool = MySqlPool::connect(database_url).await?; | ||
let session_store = MySqlStore::new(pool); | ||
session_store.migrate().await?; | ||
|
||
let deletion_task = tokio::task::spawn( | ||
session_store | ||
.clone() | ||
.continuously_delete_expired(tokio::time::Duration::from_secs(60)), | ||
); | ||
|
||
let session_service = ServiceBuilder::new() | ||
.layer(HandleErrorLayer::new(|_: BoxError| async { | ||
StatusCode::BAD_REQUEST | ||
})) | ||
.layer( | ||
SessionManagerLayer::new(session_store) | ||
.with_secure(false) | ||
.with_max_age(Duration::seconds(10)), | ||
); | ||
|
||
let app = Router::new() | ||
.route("/", get(handler)) | ||
.layer(session_service); | ||
|
||
let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); | ||
axum::Server::bind(&addr) | ||
.serve(app.into_make_service()) | ||
.await?; | ||
|
||
deletion_task.await??; | ||
|
||
Ok(()) | ||
} | ||
|
||
async fn handler(session: Session) -> impl IntoResponse { | ||
let counter: Counter = session | ||
.get(COUNTER_KEY) | ||
.expect("Could not deserialize.") | ||
.unwrap_or_default(); | ||
|
||
session | ||
.insert(COUNTER_KEY, counter.0 + 1) | ||
.expect("Could not serialize."); | ||
|
||
format!("Current count: {}", counter.0) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
use async_trait::async_trait; | ||
use sqlx::MySqlPool; | ||
use time::OffsetDateTime; | ||
|
||
use crate::{ | ||
session::{SessionId, SessionRecord}, | ||
Session, SessionStore, SqlxStoreError, | ||
}; | ||
|
||
/// A PostgreSQL session store. | ||
#[derive(Clone, Debug)] | ||
pub struct MySqlStore { | ||
pool: MySqlPool, | ||
schema_name: String, | ||
table_name: String, | ||
} | ||
|
||
impl MySqlStore { | ||
/// Create a new PostgreSQL store with the provided connection pool. | ||
/// | ||
/// # Examples | ||
/// | ||
/// ```rust,no_run | ||
/// use tower_sessions::{sqlx::MySqlPool, MySqlStore}; | ||
/// | ||
/// # tokio_test::block_on(async { | ||
/// let database_url = std::option_env!("DATABASE_URL").unwrap(); | ||
/// let pool = MySqlPool::connect(database_url).await.unwrap(); | ||
/// let session_store = MySqlStore::new(pool); | ||
/// # }) | ||
/// ``` | ||
pub fn new(pool: MySqlPool) -> Self { | ||
Self { | ||
pool, | ||
schema_name: "tower_sessions".to_string(), | ||
table_name: "session".to_string(), | ||
} | ||
} | ||
|
||
/// Migrate the session schema. | ||
/// | ||
/// # Examples | ||
/// | ||
/// ```rust,no_run | ||
/// use tower_sessions::{sqlx::MySqlPool, MySqlStore}; | ||
/// | ||
/// # tokio_test::block_on(async { | ||
/// let database_url = std::option_env!("DATABASE_URL").unwrap(); | ||
/// let pool = MySqlPool::connect(database_url).await.unwrap(); | ||
/// let session_store = MySqlStore::new(pool); | ||
/// session_store.migrate().await.unwrap(); | ||
/// # }) | ||
/// ``` | ||
pub async fn migrate(&self) -> sqlx::Result<()> { | ||
let mut tx = self.pool.begin().await?; | ||
|
||
let create_schema_query = format!( | ||
"create schema if not exists {schema_name}", | ||
schema_name = self.schema_name, | ||
); | ||
sqlx::query(&create_schema_query).execute(&mut *tx).await?; | ||
|
||
let create_table_query = format!( | ||
r#" | ||
create table if not exists `{schema_name}`.`{table_name}` | ||
( | ||
id char(36) primary key not null, | ||
expiration_time timestamp null, | ||
data text not null | ||
) | ||
"#, | ||
schema_name = self.schema_name, | ||
table_name = self.table_name | ||
); | ||
sqlx::query(&create_table_query).execute(&mut *tx).await?; | ||
|
||
tx.commit().await?; | ||
|
||
Ok(()) | ||
} | ||
|
||
#[cfg(feature = "tokio")] | ||
/// This function will keep running indefinitely, deleting expired rows and | ||
/// then waiting for the specified period before deleting again. | ||
/// | ||
/// Generally this will be used as a task, for example via | ||
/// `tokio::task::spawn`. | ||
/// | ||
/// # Arguments | ||
/// | ||
/// * `period` - The interval at which expired rows should be deleted. | ||
/// | ||
/// # Errors | ||
/// | ||
/// This function returns a `Result` that contains an error of type | ||
/// `sqlx::Error` if the deletion operation fails. | ||
/// | ||
/// # Examples | ||
/// | ||
/// ```rust,no_run | ||
/// use tower_sessions::{sqlx::MySqlPool, MySqlStore}; | ||
/// | ||
/// # tokio_test::block_on(async { | ||
/// let database_url = std::option_env!("DATABASE_URL").unwrap(); | ||
/// let pool = MySqlPool::connect(database_url).await.unwrap(); | ||
/// let session_store = MySqlStore::new(pool); | ||
/// | ||
/// tokio::task::spawn( | ||
/// session_store | ||
/// .clone() | ||
/// .continuously_delete_expired(tokio::time::Duration::from_secs(60)), | ||
/// ); | ||
/// # }) | ||
/// ``` | ||
pub async fn continuously_delete_expired( | ||
self, | ||
period: tokio::time::Duration, | ||
) -> Result<(), sqlx::Error> { | ||
let mut interval = tokio::time::interval(period); | ||
loop { | ||
self.delete_expired().await?; | ||
interval.tick().await; | ||
} | ||
} | ||
|
||
async fn delete_expired(&self) -> sqlx::Result<()> { | ||
let query = format!( | ||
r#" | ||
delete from `{schema_name}`.`{table_name}` | ||
where expiration_time < utc_timestamp() | ||
"#, | ||
schema_name = self.schema_name, | ||
table_name = self.table_name | ||
); | ||
sqlx::query(&query).execute(&self.pool).await?; | ||
Ok(()) | ||
} | ||
} | ||
|
||
#[async_trait] | ||
impl SessionStore for MySqlStore { | ||
type Error = SqlxStoreError; | ||
|
||
async fn save(&self, session_record: &SessionRecord) -> Result<(), Self::Error> { | ||
let query = format!( | ||
r#" | ||
insert into `{schema_name}`.`{table_name}` | ||
(id, expiration_time, data) values (?, ?, ?) | ||
on duplicate key update | ||
expiration_time = values(expiration_time), | ||
data = values(data) | ||
"#, | ||
schema_name = self.schema_name, | ||
table_name = self.table_name | ||
); | ||
sqlx::query(&query) | ||
.bind(&session_record.id().to_string()) | ||
.bind(session_record.expiration_time()) | ||
.bind(serde_json::to_string(&session_record.data())?) | ||
.execute(&self.pool) | ||
.await?; | ||
|
||
Ok(()) | ||
} | ||
|
||
async fn load(&self, session_id: &SessionId) -> Result<Option<Session>, Self::Error> { | ||
let query = format!( | ||
r#" | ||
select * from `{schema_name}`.`{table_name}` | ||
where id = ? | ||
and (expiration_time is null or expiration_time > ?) | ||
"#, | ||
schema_name = self.schema_name, | ||
table_name = self.table_name | ||
); | ||
let record_value: Option<(String, Option<OffsetDateTime>, String)> = sqlx::query_as(&query) | ||
.bind(session_id.to_string()) | ||
.bind(OffsetDateTime::now_utc()) | ||
.fetch_optional(&self.pool) | ||
.await?; | ||
|
||
if let Some((session_id, expiration_time, data)) = record_value { | ||
let session_id = SessionId::try_from(session_id)?; | ||
let session_record = | ||
SessionRecord::new(session_id, expiration_time, serde_json::from_str(&data)?); | ||
Ok(Some(session_record.into())) | ||
} else { | ||
Ok(None) | ||
} | ||
} | ||
|
||
async fn delete(&self, session_id: &SessionId) -> Result<(), Self::Error> { | ||
let query = format!( | ||
r#"delete from `{schema_name}`.`{table_name}` where id = ?"#, | ||
schema_name = self.schema_name, | ||
table_name = self.table_name | ||
); | ||
sqlx::query(&query) | ||
.bind(&session_id.to_string()) | ||
.execute(&self.pool) | ||
.await?; | ||
|
||
Ok(()) | ||
} | ||
} |