Skip to content

Commit

Permalink
atrium-api cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
avdb13 committed Nov 21, 2024
1 parent b083c25 commit 4d17bba
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 15 deletions.
54 changes: 43 additions & 11 deletions atrium-api/src/agent/atp_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub type AtpSession = crate::com::atproto::server::create_session::Output;
pub struct CredentialSession<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S::Error: Send + Sync + 'static,
T: XrpcClient + Send + Sync,
{
store: Arc<inner::Store<S>>,
Expand All @@ -27,6 +28,7 @@ where
impl<S, T> CredentialSession<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S::Error: Send + Sync + 'static,
T: XrpcClient + Send + Sync,
{
pub fn new(xrpc: T, store: S) -> Self {
Expand Down Expand Up @@ -58,7 +60,7 @@ where
.into(),
)
.await?;
self.store.set((), result.clone()).await.expect("todo");
self.store.set((), result.clone()).await.map_err(|e| Error::SessionStore(Box::new(e)))?;
if let Some(did_doc) = result
.did_doc
.as_ref()
Expand All @@ -73,17 +75,22 @@ where
&self,
session: AtpSession,
) -> Result<(), Error<crate::com::atproto::server::get_session::Error>> {
self.store.set((), session.clone()).await.expect("todo");
self.store.set((), session.clone()).await.map_err(|e| Error::SessionStore(Box::new(e)))?;
let result = self.api.com.atproto.server.get_session().await;
match result {
Ok(output) => {
assert_eq!(output.data.did, session.data.did);
if let Some(mut session) = self.store.get(&()).await.expect("todo") {
if let Some(mut session) =
self.store.get(&()).await.map_err(|e| Error::SessionStore(Box::new(e)))?
{
session.did_doc = output.data.did_doc.clone();
session.email = output.data.email;
session.email_confirmed = output.data.email_confirmed;
session.handle = output.data.handle;
self.store.set((), session).await.expect("todo");
self.store
.set((), session)
.await
.map_err(|e| Error::SessionStore(Box::new(e)))?;
}
if let Some(did_doc) = output
.data
Expand All @@ -96,7 +103,7 @@ where
Ok(())
}
Err(err) => {
self.store.clear().await.expect("todo");
self.store.clear().await.map_err(|e| Error::SessionStore(Box::new(e)))?;
Err(err)
}
}
Expand Down Expand Up @@ -125,7 +132,7 @@ where
}
/// Get the current session.
pub async fn get_session(&self) -> Option<AtpSession> {
self.store.get(&()).await.expect("todo")
self.store.get(&()).await.transpose().and_then(Result::ok)
}
/// Get the current endpoint.
pub async fn get_endpoint(&self) -> String {
Expand All @@ -146,6 +153,7 @@ where
pub struct AtpAgent<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S::Error: Send + Sync + 'static,
T: XrpcClient + Send + Sync,
{
inner: CredentialSession<S, T>,
Expand All @@ -154,6 +162,7 @@ where
impl<S, T> AtpAgent<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S::Error: Send + Sync + 'static,
T: XrpcClient + Send + Sync,
{
/// Create a new agent.
Expand All @@ -165,6 +174,7 @@ where
impl<S, T> Deref for AtpAgent<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S::Error: Send + Sync + 'static,
T: XrpcClient + Send + Sync,
{
type Target = CredentialSession<S, T>;
Expand Down Expand Up @@ -365,7 +375,11 @@ mod tests {
..Default::default()
};
let agent = AtpAgent::new(client, MemoryMapStore::default());
agent.store.set((), session_data.clone().into()).await.expect("todo");
agent
.store
.set((), session_data.clone().into())
.await
.expect("set session should be succeeded");
let output = agent
.api
.com
Expand Down Expand Up @@ -399,7 +413,11 @@ mod tests {
..Default::default()
};
let agent = AtpAgent::new(client, MemoryMapStore::default());
agent.store.set((), session_data.clone().into()).await.expect("todo");
agent
.store
.set((), session_data.clone().into())
.await
.expect("set session should be succeeded");
let output = agent
.api
.com
Expand All @@ -410,7 +428,12 @@ mod tests {
.expect("get session should be succeeded");
assert_eq!(output.did.as_str(), "did:web:example.com");
assert_eq!(
agent.store.get(&()).await.expect("todo").map(|session| session.data.access_jwt),
agent
.store
.get(&())
.await
.expect("get session should be succeeded")
.map(|session| session.data.access_jwt),
Some("access".into())
);
}
Expand Down Expand Up @@ -438,7 +461,11 @@ mod tests {
};
let counts = Arc::clone(&client.counts);
let agent = Arc::new(AtpAgent::new(client, MemoryMapStore::default()));
agent.store.set((), session_data.clone().into()).await.expect("todo");
agent
.store
.set((), session_data.clone().into())
.await
.expect("set session should be succeeded");
let handles = (0..3).map(|_| {
let agent = Arc::clone(&agent);
tokio::spawn(async move { agent.api.com.atproto.server.get_session().await })
Expand All @@ -453,7 +480,12 @@ mod tests {
assert_eq!(output.did.as_str(), "did:web:example.com");
}
assert_eq!(
agent.store.get(&()).await.expect("todo").map(|session| session.data.access_jwt),
agent
.store
.get(&())
.await
.expect("get session should be succeeded")
.map(|session| session.data.access_jwt),
Some("access".into())
);
assert_eq!(
Expand Down
9 changes: 5 additions & 4 deletions atrium-api/src/agent/atp_agent/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ where
self.store.get_endpoint()
}
async fn authorization_token(&self, is_refresh: bool) -> Option<AuthorizationToken> {
self.store.get(&()).await.expect("todo").map(|session| {
self.store.get(&()).await.transpose().and_then(core::result::Result::ok).map(|session| {
AuthorizationToken::Bearer(if is_refresh {
session.data.refresh_jwt
} else {
Expand Down Expand Up @@ -157,13 +157,13 @@ where
}
async fn refresh_session_inner(&self) {
if let Ok(output) = self.call_refresh_session().await {
if let Some(mut session) = self.store.get(&()).await.expect("todo") {
if let Ok(Some(mut session)) = self.store.get(&()).await {
session.access_jwt = output.data.access_jwt;
session.did = output.data.did;
session.did_doc = output.data.did_doc.clone();
session.handle = output.data.handle;
session.refresh_jwt = output.data.refresh_jwt;
self.store.set((), session).await.expect("todo");
let _ = self.store.set((), session).await;
}
if let Some(did_doc) = output
.data
Expand All @@ -174,7 +174,7 @@ where
self.store.update_endpoint(&did_doc);
}
} else {
self.store.clear().await.expect("todo");
let _ = self.store.clear().await;
}
}
// same as `crate::client::com::atproto::server::Service::refresh_session()`
Expand Down Expand Up @@ -247,6 +247,7 @@ where
impl<S, T> XrpcClient for Client<S, T>
where
S: MapStore<(), AtpSession> + Send + Sync,
S::Error: Send + Sync + 'static,
T: XrpcClient + Send + Sync,
{
fn base_uri(&self) -> String {
Expand Down
2 changes: 2 additions & 0 deletions atrium-xrpc/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ where
SerdeJson(#[from] serde_json::Error),
#[error("serde_html_form error: {0}")]
SerdeHtmlForm(#[from] serde_html_form::ser::Error),
#[error("session store error: {0}")]
SessionStore(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("unexpected response type")]
UnexpectedResponseType,
}
Expand Down

0 comments on commit 4d17bba

Please sign in to comment.