Skip to content

Commit

Permalink
Support AsyncRead+Write for receiver
Browse files Browse the repository at this point in the history
  • Loading branch information
lemaitre-aneo committed Apr 13, 2024
1 parent 88f6bb2 commit c402efc
Show file tree
Hide file tree
Showing 9 changed files with 316 additions and 152 deletions.
2 changes: 1 addition & 1 deletion examples/simple_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl russh::client::Handler for Handler {
}
}

#[tokio::main]
#[tokio::main(flavor = "current_thread")]
pub async fn main() -> Result<(), Box<dyn std::error::Error>> {
env_logger::init();
// You can start a sftp server configured for this client with the following command:
Expand Down
3 changes: 3 additions & 0 deletions src/client/dir/close.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ impl<'a> DirClosing<'a> {
dir.buffer = None;
dir.pending = None;
if let Some(handle) = dir.handle.take() {
log::trace!("wait for closing");
let pending = dir.client.close(handle.clone());
return DirClosing(DirClosingState::Closing {
dir,
Expand All @@ -94,8 +95,10 @@ impl<'a> DirClosing<'a> {

let stop = SftpClientStopping::new(&mut dir.client);
if stop.is_stopped() {
log::trace!("closed and stopped");
DirClosing(DirClosingState::Closed)
} else {
log::trace!("closed, wait for stopping");
DirClosing(DirClosingState::Stopping(stop))
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/client/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ impl From<russh::Error> for Error {
}
}

impl From<crate::message::DecodeError> for Error {
fn from(value: crate::message::DecodeError) -> Self {
Self::WireFormat(value.inner)
}
}

impl From<Error> for std::io::Error {
fn from(value: Error) -> Self {
match value {
Expand Down
3 changes: 3 additions & 0 deletions src/client/file/close.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ impl<'a> FileClosing<'a> {
file.pending = PendingOperation::None;
if let Some(handle) = file.handle.take() {
if let Some(handle) = Arc::into_inner(handle) {
log::trace!("wait for closing");
let pending = file.client.close(handle.clone());
return FileClosing(FileClosingState::Closing {
file,
Expand All @@ -91,8 +92,10 @@ impl<'a> FileClosing<'a> {

let stop = SftpClientStopping::new(&mut file.client);
if stop.is_stopped() {
log::trace!("closed and stopped");
FileClosing(FileClosingState::Closed)
} else {
log::trace!("closed, wait for stopping");
FileClosing(FileClosingState::Stopping(stop))
}
}
Expand Down
125 changes: 63 additions & 62 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
use std::sync::Arc;

use async_trait::async_trait;
use russh::{client::Msg, Channel, ChannelMsg};
use tokio::sync::mpsc;
use russh::ChannelStream;
use russh::{client::Msg, Channel};
use tokio::io::AsyncWrite;
use tokio::task::JoinHandle;
use tokio::{io::AsyncRead, sync::mpsc};

use crate::message::{Init, Message, StatusCode, Version};

Expand Down Expand Up @@ -91,63 +93,47 @@ impl SftpClient {
/// `ssh` can be a [`russh::Channel<Msg>`])
/// or a [`russh::client::Handler`].
/// In case of the handler, it can be moved or borrowed.
pub async fn new<T: ToSftpChannel>(ssh: T) -> Result<Self, Error> {
Self::with_channel(ssh.to_sftp_channel().await?).await
pub async fn new<T: IntoSftpStream>(ssh: T) -> Result<Self, Error> {
Self::with_stream(ssh.into_sftp_stream().await?).await
}

/// Creates a new client from a [`russh::Channel<Msg>`].
pub async fn with_channel(mut channel: Channel<Msg>) -> Result<Self, Error> {
// Start SFTP subsystem
channel.request_subsystem(false, "sftp").await?;

pub async fn with_stream(
mut stream: impl AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
) -> Result<Self, Error> {
// Init SFTP handshake
let init_message = Message::Init(Init {
version: 3,
extensions: Default::default(),
});
let init_frame = init_message.encode(0)?;
channel.data(init_frame.as_ref()).await?;

// Check handshake response
loop {
match channel.wait().await {
Some(ChannelMsg::Data { data }) => {
match Message::decode(data.as_ref()) {
// Valid response: continue
Ok((
_,
Message::Version(Version {
version: 3,
extensions: _,
}),
)) => break,

// Invalid responses: abort
Ok((_, Message::Version(_))) => {
return Err(StatusCode::BadMessage
.to_status("Invalid sftp version")
.into());
}
Ok(_) => {
return Err(StatusCode::BadMessage.to_status("Bad SFTP init").into());
}
Err(err) => {
return Err(err.into());
}
}
}
// Unrelated event has been received, looping is required
Some(_) => (),
// Channel has been closed
None => {
return Err(StatusCode::BadMessage
.to_status("Failed to start SFTP subsystem")
.into());
}
receiver::write_msg(
&mut stream,
Message::Init(Init {
version: 3,
extensions: Default::default(),
}),
3,
)
.await?;

match receiver::read_msg(&mut stream).await? {
// Valid response: continue
(
_,
Message::Version(Version {
version: 3,
extensions: _,
}),
) => (),

// Invalid responses: abort
(_, Message::Version(_)) => {
return Err(StatusCode::BadMessage
.to_status("Invalid sftp version")
.into());
}
_ => {
return Err(StatusCode::BadMessage.to_status("Bad SFTP init").into());
}
}

let (receiver, tx) = receiver::Receiver::new(channel);
let (receiver, tx) = receiver::Receiver::new(stream);
let request_processor = tokio::spawn(receiver.run());

Ok(Self {
Expand All @@ -165,27 +151,42 @@ impl std::fmt::Debug for SftpClient {

/// Convert the object to a SSH channel
#[async_trait]
pub trait ToSftpChannel {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, Error>;
pub trait IntoSftpStream {
type Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static;
async fn into_sftp_stream(self) -> Result<Self::Stream, Error>;
}

#[async_trait]
impl ToSftpChannel for Channel<Msg> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, Error> {
impl IntoSftpStream for ChannelStream<Msg> {
type Stream = ChannelStream<Msg>;
async fn into_sftp_stream(self) -> Result<Self::Stream, Error> {
Ok(self)
}
}

#[async_trait]
impl<H: russh::client::Handler> ToSftpChannel for &russh::client::Handle<H> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, Error> {
self.channel_open_session().await.map_err(Into::into)
impl IntoSftpStream for Channel<Msg> {
type Stream = ChannelStream<Msg>;
async fn into_sftp_stream(self) -> Result<Self::Stream, Error> {
// Start SFTP subsystem
self.request_subsystem(false, "sftp").await?;

Ok(self.into_stream())
}
}

#[async_trait]
impl<H: russh::client::Handler> IntoSftpStream for &russh::client::Handle<H> {
type Stream = ChannelStream<Msg>;
async fn into_sftp_stream(self) -> Result<Self::Stream, Error> {
self.channel_open_session().await?.into_sftp_stream().await
}
}

#[async_trait]
impl<H: russh::client::Handler> ToSftpChannel for russh::client::Handle<H> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, Error> {
(&self).to_sftp_channel().await
impl<H: russh::client::Handler> IntoSftpStream for russh::client::Handle<H> {
type Stream = ChannelStream<Msg>;
async fn into_sftp_stream(self) -> Result<Self::Stream, Error> {
(&self).into_sftp_stream().await
}
}
Loading

0 comments on commit c402efc

Please sign in to comment.