Skip to content

Commit

Permalink
provide mysql store
Browse files Browse the repository at this point in the history
  • Loading branch information
maxcountryman committed Sep 23, 2023
1 parent 1f8338f commit 60bc7de
Show file tree
Hide file tree
Showing 5 changed files with 284 additions and 0 deletions.
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ redis-store = ["fred"]
sqlx-store = ["sqlx"]
sqlite-store = ["sqlx/sqlite", "sqlx-store"]
postgres-store = ["sqlx/postgres", "sqlx-store"]
mysql-store = ["sqlx/mysql", "sqlx-store"]
tokio = ["dep:tokio"]

[dependencies]
Expand Down Expand Up @@ -66,6 +67,10 @@ required-features = ["axum-core", "sqlite-store", "tokio"]
name = "postgres-store"
required-features = ["axum-core", "postgres-store", "tokio"]

[[example]]
name = "mysql-store"
required-features = ["axum-core", "mysql-store", "tokio"]

[[example]]
name = "strongly-typed"
required-features = ["axum-core", "memory-store"]
64 changes: 64 additions & 0 deletions examples/mysql-store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
use std::net::SocketAddr;

use axum::{
error_handling::HandleErrorLayer, response::IntoResponse, routing::get, BoxError, Router,
};
use http::StatusCode;
use serde::{Deserialize, Serialize};
use tower::ServiceBuilder;
use tower_sessions::{sqlx::MySqlPool, time::Duration, MySqlStore, Session, SessionManagerLayer};

const COUNTER_KEY: &str = "counter";

#[derive(Serialize, Deserialize, Default)]
struct Counter(usize);

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let database_url = std::option_env!("DATABASE_URL").expect("Missing DATABASE_URL.");
let pool = MySqlPool::connect(database_url).await?;
let session_store = MySqlStore::new(pool);
session_store.migrate().await?;

let deletion_task = tokio::task::spawn(
session_store
.clone()
.continuously_delete_expired(tokio::time::Duration::from_secs(60)),
);

let session_service = ServiceBuilder::new()
.layer(HandleErrorLayer::new(|_: BoxError| async {
StatusCode::BAD_REQUEST
}))
.layer(
SessionManagerLayer::new(session_store)
.with_secure(false)
.with_max_age(Duration::seconds(10)),
);

let app = Router::new()
.route("/", get(handler))
.layer(session_service);

let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
axum::Server::bind(&addr)
.serve(app.into_make_service())
.await?;

deletion_task.await??;

Ok(())
}

async fn handler(session: Session) -> impl IntoResponse {
let counter: Counter = session
.get(COUNTER_KEY)
.expect("Could not deserialize.")
.unwrap_or_default();

session
.insert(COUNTER_KEY, counter.0 + 1)
.expect("Could not serialize.");

format!("Current count: {}", counter.0)
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ pub use self::memory_store::MemoryStore;
#[cfg(feature = "redis-store")]
#[cfg_attr(docsrs, doc(cfg(feature = "redis-store")))]
pub use self::redis_store::RedisStore;
#[cfg(feature = "mysql-store")]
#[cfg_attr(docsrs, doc(cfg(feature = "mysql-store")))]
pub use self::sqlx_store::MySqlStore;
#[cfg(feature = "postgres-store")]
#[cfg_attr(docsrs, doc(cfg(feature = "postgres-store")))]
pub use self::sqlx_store::PostgresStore;
Expand Down
7 changes: 7 additions & 0 deletions src/sqlx_store.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#[cfg(feature = "mysql-store")]
#[cfg_attr(docsrs, doc(cfg(feature = "mysql-store")))]
pub use self::mysql_store::MySqlStore;
#[cfg(feature = "postgres-store")]
#[cfg_attr(docsrs, doc(cfg(feature = "postgres-store")))]
pub use self::postgres_store::PostgresStore;
Expand All @@ -14,6 +17,10 @@ mod sqlite_store;
#[cfg_attr(docsrs, doc(cfg(feature = "postgres-store")))]
mod postgres_store;

#[cfg(feature = "mysql-store")]
#[cfg_attr(docsrs, doc(cfg(feature = "mysql-store")))]
mod mysql_store;

/// An error type for SQLx stores.
#[allow(clippy::enum_variant_names)]
#[derive(thiserror::Error, Debug)]
Expand Down
205 changes: 205 additions & 0 deletions src/sqlx_store/mysql_store.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
use async_trait::async_trait;
use sqlx::MySqlPool;
use time::OffsetDateTime;

use crate::{
session::{SessionId, SessionRecord},
Session, SessionStore, SqlxStoreError,
};

/// A PostgreSQL session store.
#[derive(Clone, Debug)]
pub struct MySqlStore {
pool: MySqlPool,
schema_name: String,
table_name: String,
}

impl MySqlStore {
/// Create a new PostgreSQL store with the provided connection pool.
///
/// # Examples
///
/// ```rust,no_run
/// use tower_sessions::{sqlx::MySqlPool, MySqlStore};
///
/// # tokio_test::block_on(async {
/// let database_url = std::option_env!("DATABASE_URL").unwrap();
/// let pool = MySqlPool::connect(database_url).await.unwrap();
/// let session_store = MySqlStore::new(pool);
/// # })
/// ```
pub fn new(pool: MySqlPool) -> Self {
Self {
pool,
schema_name: "tower_sessions".to_string(),
table_name: "session".to_string(),
}
}

/// Migrate the session schema.
///
/// # Examples
///
/// ```rust,no_run
/// use tower_sessions::{sqlx::MySqlPool, MySqlStore};
///
/// # tokio_test::block_on(async {
/// let database_url = std::option_env!("DATABASE_URL").unwrap();
/// let pool = MySqlPool::connect(database_url).await.unwrap();
/// let session_store = MySqlStore::new(pool);
/// session_store.migrate().await.unwrap();
/// # })
/// ```
pub async fn migrate(&self) -> sqlx::Result<()> {
let mut tx = self.pool.begin().await?;

let create_schema_query = format!(
"create schema if not exists {schema_name}",
schema_name = self.schema_name,
);
sqlx::query(&create_schema_query).execute(&mut *tx).await?;

let create_table_query = format!(
r#"
create table if not exists `{schema_name}`.`{table_name}`
(
id char(36) primary key not null,
expiration_time timestamp null,
data text not null
)
"#,
schema_name = self.schema_name,
table_name = self.table_name
);
sqlx::query(&create_table_query).execute(&mut *tx).await?;

tx.commit().await?;

Ok(())
}

#[cfg(feature = "tokio")]
/// This function will keep running indefinitely, deleting expired rows and
/// then waiting for the specified period before deleting again.
///
/// Generally this will be used as a task, for example via
/// `tokio::task::spawn`.
///
/// # Arguments
///
/// * `period` - The interval at which expired rows should be deleted.
///
/// # Errors
///
/// This function returns a `Result` that contains an error of type
/// `sqlx::Error` if the deletion operation fails.
///
/// # Examples
///
/// ```rust,no_run
/// use tower_sessions::{sqlx::MySqlPool, MySqlStore};
///
/// # tokio_test::block_on(async {
/// let database_url = std::option_env!("DATABASE_URL").unwrap();
/// let pool = MySqlPool::connect(database_url).await.unwrap();
/// let session_store = MySqlStore::new(pool);
///
/// tokio::task::spawn(
/// session_store
/// .clone()
/// .continuously_delete_expired(tokio::time::Duration::from_secs(60)),
/// );
/// # })
/// ```
pub async fn continuously_delete_expired(
self,
period: tokio::time::Duration,
) -> Result<(), sqlx::Error> {
let mut interval = tokio::time::interval(period);
loop {
self.delete_expired().await?;
interval.tick().await;
}
}

async fn delete_expired(&self) -> sqlx::Result<()> {
let query = format!(
r#"
delete from `{schema_name}`.`{table_name}`
where expiration_time < utc_timestamp()
"#,
schema_name = self.schema_name,
table_name = self.table_name
);
sqlx::query(&query).execute(&self.pool).await?;
Ok(())
}
}

#[async_trait]
impl SessionStore for MySqlStore {
type Error = SqlxStoreError;

async fn save(&self, session_record: &SessionRecord) -> Result<(), Self::Error> {
let query = format!(
r#"
insert into `{schema_name}`.`{table_name}`
(id, expiration_time, data) values (?, ?, ?)
on duplicate key update
expiration_time = values(expiration_time),
data = values(data)
"#,
schema_name = self.schema_name,
table_name = self.table_name
);
sqlx::query(&query)
.bind(&session_record.id().to_string())
.bind(session_record.expiration_time())
.bind(serde_json::to_string(&session_record.data())?)
.execute(&self.pool)
.await?;

Ok(())
}

async fn load(&self, session_id: &SessionId) -> Result<Option<Session>, Self::Error> {
let query = format!(
r#"
select * from `{schema_name}`.`{table_name}`
where id = ?
and (expiration_time is null or expiration_time > ?)
"#,
schema_name = self.schema_name,
table_name = self.table_name
);
let record_value: Option<(String, Option<OffsetDateTime>, String)> = sqlx::query_as(&query)
.bind(session_id.to_string())
.bind(OffsetDateTime::now_utc())
.fetch_optional(&self.pool)
.await?;

if let Some((session_id, expiration_time, data)) = record_value {
let session_id = SessionId::try_from(session_id)?;
let session_record =
SessionRecord::new(session_id, expiration_time, serde_json::from_str(&data)?);
Ok(Some(session_record.into()))
} else {
Ok(None)
}
}

async fn delete(&self, session_id: &SessionId) -> Result<(), Self::Error> {
let query = format!(
r#"delete from `{schema_name}`.`{table_name}` where id = ?"#,
schema_name = self.schema_name,
table_name = self.table_name
);
sqlx::query(&query)
.bind(&session_id.to_string())
.execute(&self.pool)
.await?;

Ok(())
}
}

0 comments on commit 60bc7de

Please sign in to comment.