diff --git a/russh-keys/src/lib.rs b/russh-keys/src/lib.rs index a819c63a..6ecaad61 100644 --- a/russh-keys/src/lib.rs +++ b/russh-keys/src/lib.rs @@ -920,42 +920,49 @@ Cog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux } } let agent_path_ = agent_path.clone(); - core.spawn(async move { - let mut listener = tokio::net::UnixListener::bind(&agent_path_).unwrap(); - - agent::server::serve( - Incoming { - listener: &mut listener, - }, - X {}, - ) - .await.unwrap() - }); let key = decode_secret_key(PKCS8_ENCRYPTED, Some("blabla")).unwrap(); core.block_on(async move { - let public = key.public_key(); - let stream = tokio::net::UnixStream::connect(&agent_path).await.unwrap(); - let mut client = agent::client::AgentClient::connect(stream); - client - .add_identity(&key, &[agent::Constraint::KeyLifetime { seconds: 60 }]) - .await - .unwrap(); - client.request_identities().await.unwrap(); - let buf = russh_cryptovec::CryptoVec::from_slice(b"blabla"); - let len = buf.len(); - let buf = client.sign_request(&public, buf).await.unwrap(); - let (a, b) = buf.split_at(len); - if let ssh_key::public::KeyData::Ed25519 { .. } = public.key_data() { - let sig = &b[b.len() - 64..]; - let sig = ssh_key::Signature::new(key.algorithm(), sig).unwrap(); - assert!(Verifier::verify(public, a, &sig).is_ok()); - } + tokio::join!( + async move { + let mut listener = tokio::net::UnixListener::bind(&agent_path_).unwrap(); + + agent::server::serve( + Incoming { + listener: &mut listener, + connections_left: 1, + }, + X {}, + ) + .await + .unwrap() + }, + async move { + let public = key.public_key(); + let stream = tokio::net::UnixStream::connect(&agent_path).await.unwrap(); + let mut client = agent::client::AgentClient::connect(stream); + client + .add_identity(&key, &[agent::Constraint::KeyLifetime { seconds: 60 }]) + .await + .unwrap(); + client.request_identities().await.unwrap(); + let buf = russh_cryptovec::CryptoVec::from_slice(b"blabla"); + let len = buf.len(); + let buf = client.sign_request(&public, buf).await.unwrap(); + let (a, b) = buf.split_at(len); + if let ssh_key::public::KeyData::Ed25519 { .. } = public.key_data() { + let sig = &b[b.len() - 64..]; + let sig = ssh_key::Signature::new(key.algorithm(), sig).unwrap(); + assert!(Verifier::verify(public, a, &sig).is_ok()); + } + } + ); }) } #[cfg(unix)] struct Incoming<'a> { listener: &'a mut tokio::net::UnixListener, + connections_left: usize, } #[cfg(unix)] @@ -966,7 +973,12 @@ Cog3JMeTrb3LiPHgN6gU2P30MRp6L1j1J/MtlOAr5rux self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>, ) -> std::task::Poll> { - let (sock, _addr) = futures::ready!(self.get_mut().listener.poll_accept(cx))?; + let this = self.get_mut(); + if this.connections_left == 0 { + return std::task::Poll::Ready(None); + } + let (sock, _addr) = futures::ready!(this.listener.poll_accept(cx))?; + this.connections_left -= 1; std::task::Poll::Ready(Some(Ok(sock))) } }