diff --git a/examples/simple_client.rs b/examples/simple_client.rs index bc28f6d..97014b4 100644 --- a/examples/simple_client.rs +++ b/examples/simple_client.rs @@ -35,7 +35,7 @@ impl russh::client::Handler for Handler { } } -#[tokio::main] +#[tokio::main(flavor = "current_thread")] pub async fn main() -> Result<(), Box> { env_logger::init(); // You can start a sftp server configured for this client with the following command: diff --git a/src/client/dir/close.rs b/src/client/dir/close.rs index 15b7c4f..57887ac 100644 --- a/src/client/dir/close.rs +++ b/src/client/dir/close.rs @@ -84,6 +84,7 @@ impl<'a> DirClosing<'a> { dir.buffer = None; dir.pending = None; if let Some(handle) = dir.handle.take() { + log::trace!("wait for closing"); let pending = dir.client.close(handle.clone()); return DirClosing(DirClosingState::Closing { dir, @@ -94,8 +95,10 @@ impl<'a> DirClosing<'a> { let stop = SftpClientStopping::new(&mut dir.client); if stop.is_stopped() { + log::trace!("closed and stopped"); DirClosing(DirClosingState::Closed) } else { + log::trace!("closed, wait for stopping"); DirClosing(DirClosingState::Stopping(stop)) } } diff --git a/src/client/error.rs b/src/client/error.rs index c4895ac..0798f02 100644 --- a/src/client/error.rs +++ b/src/client/error.rs @@ -47,6 +47,12 @@ impl From for Error { } } +impl From for Error { + fn from(value: crate::message::DecodeError) -> Self { + Self::WireFormat(value.inner) + } +} + impl From for std::io::Error { fn from(value: Error) -> Self { match value { diff --git a/src/client/file/close.rs b/src/client/file/close.rs index b8b59e6..4c58f25 100644 --- a/src/client/file/close.rs +++ b/src/client/file/close.rs @@ -80,6 +80,7 @@ impl<'a> FileClosing<'a> { file.pending = PendingOperation::None; if let Some(handle) = file.handle.take() { if let Some(handle) = Arc::into_inner(handle) { + log::trace!("wait for closing"); let pending = file.client.close(handle.clone()); return FileClosing(FileClosingState::Closing { file, @@ -91,8 +92,10 @@ impl<'a> FileClosing<'a> { let stop = SftpClientStopping::new(&mut file.client); if stop.is_stopped() { + log::trace!("closed and stopped"); FileClosing(FileClosingState::Closed) } else { + log::trace!("closed, wait for stopping"); FileClosing(FileClosingState::Stopping(stop)) } } diff --git a/src/client/mod.rs b/src/client/mod.rs index 64f3ad1..75235fd 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -21,9 +21,11 @@ use std::sync::Arc; use async_trait::async_trait; -use russh::{client::Msg, Channel, ChannelMsg}; -use tokio::sync::mpsc; +use russh::ChannelStream; +use russh::{client::Msg, Channel}; +use tokio::io::AsyncWrite; use tokio::task::JoinHandle; +use tokio::{io::AsyncRead, sync::mpsc}; use crate::message::{Init, Message, StatusCode, Version}; @@ -91,63 +93,47 @@ impl SftpClient { /// `ssh` can be a [`russh::Channel`]) /// or a [`russh::client::Handler`]. /// In case of the handler, it can be moved or borrowed. - pub async fn new(ssh: T) -> Result { - Self::with_channel(ssh.to_sftp_channel().await?).await + pub async fn new(ssh: T) -> Result { + Self::with_stream(ssh.into_sftp_stream().await?).await } /// Creates a new client from a [`russh::Channel`]. - pub async fn with_channel(mut channel: Channel) -> Result { - // Start SFTP subsystem - channel.request_subsystem(false, "sftp").await?; - + pub async fn with_stream( + mut stream: impl AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static, + ) -> Result { // Init SFTP handshake - let init_message = Message::Init(Init { - version: 3, - extensions: Default::default(), - }); - let init_frame = init_message.encode(0)?; - channel.data(init_frame.as_ref()).await?; - - // Check handshake response - loop { - match channel.wait().await { - Some(ChannelMsg::Data { data }) => { - match Message::decode(data.as_ref()) { - // Valid response: continue - Ok(( - _, - Message::Version(Version { - version: 3, - extensions: _, - }), - )) => break, - - // Invalid responses: abort - Ok((_, Message::Version(_))) => { - return Err(StatusCode::BadMessage - .to_status("Invalid sftp version") - .into()); - } - Ok(_) => { - return Err(StatusCode::BadMessage.to_status("Bad SFTP init").into()); - } - Err(err) => { - return Err(err.into()); - } - } - } - // Unrelated event has been received, looping is required - Some(_) => (), - // Channel has been closed - None => { - return Err(StatusCode::BadMessage - .to_status("Failed to start SFTP subsystem") - .into()); - } + receiver::write_msg( + &mut stream, + Message::Init(Init { + version: 3, + extensions: Default::default(), + }), + 3, + ) + .await?; + + match receiver::read_msg(&mut stream).await? { + // Valid response: continue + ( + _, + Message::Version(Version { + version: 3, + extensions: _, + }), + ) => (), + + // Invalid responses: abort + (_, Message::Version(_)) => { + return Err(StatusCode::BadMessage + .to_status("Invalid sftp version") + .into()); + } + _ => { + return Err(StatusCode::BadMessage.to_status("Bad SFTP init").into()); } } - let (receiver, tx) = receiver::Receiver::new(channel); + let (receiver, tx) = receiver::Receiver::new(stream); let request_processor = tokio::spawn(receiver.run()); Ok(Self { @@ -165,27 +151,42 @@ impl std::fmt::Debug for SftpClient { /// Convert the object to a SSH channel #[async_trait] -pub trait ToSftpChannel { - async fn to_sftp_channel(self) -> Result, Error>; +pub trait IntoSftpStream { + type Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static; + async fn into_sftp_stream(self) -> Result; } #[async_trait] -impl ToSftpChannel for Channel { - async fn to_sftp_channel(self) -> Result, Error> { +impl IntoSftpStream for ChannelStream { + type Stream = ChannelStream; + async fn into_sftp_stream(self) -> Result { Ok(self) } } #[async_trait] -impl ToSftpChannel for &russh::client::Handle { - async fn to_sftp_channel(self) -> Result, Error> { - self.channel_open_session().await.map_err(Into::into) +impl IntoSftpStream for Channel { + type Stream = ChannelStream; + async fn into_sftp_stream(self) -> Result { + // Start SFTP subsystem + self.request_subsystem(false, "sftp").await?; + + Ok(self.into_stream()) + } +} + +#[async_trait] +impl IntoSftpStream for &russh::client::Handle { + type Stream = ChannelStream; + async fn into_sftp_stream(self) -> Result { + self.channel_open_session().await?.into_sftp_stream().await } } #[async_trait] -impl ToSftpChannel for russh::client::Handle { - async fn to_sftp_channel(self) -> Result, Error> { - (&self).to_sftp_channel().await +impl IntoSftpStream for russh::client::Handle { + type Stream = ChannelStream; + async fn into_sftp_stream(self) -> Result { + (&self).into_sftp_stream().await } } diff --git a/src/client/receiver.rs b/src/client/receiver.rs index b84fdfe..e25d623 100644 --- a/src/client/receiver.rs +++ b/src/client/receiver.rs @@ -15,138 +15,217 @@ // limitations under the License. use std::collections::HashMap; +use std::pin::Pin; +use std::task::Poll; -use bytes::Buf; -use russh::{client::Msg, Channel, ChannelMsg}; +use bytes::{Buf, Bytes, BytesMut}; +use futures::{Stream, StreamExt}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::{mpsc, oneshot}; use crate::client::Error; -use crate::message::Message; +use crate::message::{Message, StatusCode}; pub(super) type Response = Result; -pub(super) struct Request(pub(super) Message, pub(super) oneshot::Sender); +pub struct Request(pub(super) Message, pub(super) oneshot::Sender); -pub(super) struct Receiver { +pub(super) struct Receiver { onflight: HashMap>, next_id: u32, commands: mpsc::UnboundedReceiver, - channel: Channel, + stream: S, + response_size: Option, + response_buffer: BytesMut, } -impl Receiver { +impl Receiver { /// Create a new receiver - pub(super) fn new(channel: Channel) -> (Self, mpsc::UnboundedSender) { + pub(super) fn new(stream: S) -> (Self, mpsc::UnboundedSender) { let (tx, rx) = mpsc::unbounded_channel(); ( Self { onflight: HashMap::new(), next_id: 0, commands: rx, - channel, + stream, + response_size: None, + response_buffer: Default::default(), }, tx, ) } +} - /// Run a receiver until the ssh channel is closed or no more commands can be sent - pub(super) async fn run(mut self) { - log::debug!("Start SFTP client"); - loop { - tokio::select! { - // New request to send - request = self.commands.recv() => { - // If received null, the commands channel has been closed - let Some(Request(message, tx)) = request else { - log::debug!("Command channel closed"); - break; - }; - - self.process_command(message, tx).await; - } - - // New response received - response = self.channel.wait() => { - // If received null, the SSH channel has been closed - let Some(ChannelMsg::Data { data }) = response else { - log::debug!("SFTP channel closed"); - break; - }; +pub enum StreamItem { + Request(Request), + Response(Bytes), + Error(std::io::Error), +} - self.process_response(&data).await; +impl Stream for Receiver { + type Item = StreamItem; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + // Check if new commands have been sent + match self.commands.poll_recv(cx) { + Poll::Ready(Some(request)) => { + return Poll::Ready(Some(StreamItem::Request(request))); + } + Poll::Ready(None) => { + // If commands are closed and no request is on-flight, + // No more messages could be received + if self.onflight.is_empty() { + return Poll::Ready(None); } } - } + Poll::Pending => (), + }; - while !self.onflight.is_empty() { - // If received null, the SSH channel has been closed - let Some(ChannelMsg::Data { data }) = self.channel.wait().await else { - break; - }; - - self.process_response(&data).await; - } + // No command was available, trying to read responses from the stream + loop { + let new_len; + match self.response_size { + // A size has already been read from the stream + Some(response_size) => { + if self.response_buffer.len() >= response_size as usize { + self.response_size = None; + let response = self.response_buffer.split_to(response_size as usize); + return Poll::Ready(Some(StreamItem::Response(response.freeze()))); + } + new_len = response_size as usize; + } + // Must read the size of the frame from the stream + None => { + if self.response_buffer.len() >= std::mem::size_of::() { + let len = self.response_buffer.get_u32(); + self.response_size = Some(len); + continue; + } + new_len = std::mem::size_of::(); + } + } - self.commands.close(); - if let Err(err) = self.channel.close().await { - log::warn!("Error while closing SSH channel: {err:?}"); - } + let old_len = self.response_buffer.len(); - log::debug!("SFTP client stopped"); - } + // taking is required to avoid borrowinf multiple times `self` + let mut buffer = std::mem::take(&mut self.response_buffer); - /// Process a command request - async fn process_command(&mut self, message: Message, tx: oneshot::Sender) { - self.next_id += 1; - let id = self.next_id; + // tries to read the whole frame, or the next kilobyte + buffer.resize(new_len.max(1024), 0); + let mut read_buf = tokio::io::ReadBuf::new(&mut buffer[old_len..]); + let read = Pin::new(&mut self.stream).poll_read(cx, &mut read_buf); - log::trace!("Request #{id}: {message:?}"); + // Adjust buffer size according to what was read + let len = read_buf.filled().len(); + buffer.resize(old_len + len, 0); + self.response_buffer = buffer; - match message.encode(id) { - Ok(frame) => match self.channel.data(frame.as_ref()).await { - Ok(()) => { - self.onflight.insert(id, tx); + // Check status of reading + match read { + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(err)) => { + return Poll::Ready(Some(StreamItem::Error(err))); } - Err(err) => { - log::debug!("Could not send request #{id}: {err:?}"); - send_message(tx, Err(err.into())); + Poll::Pending => { + return Poll::Pending; } - }, - Err(err) => { - log::debug!("Could not encode request #{id}: {err:?}"); - send_message(tx, Err(err.into())); + } + + // EoF + if len == old_len { + return Poll::Ready(None); } } } +} - /// Process a SSH response - async fn process_response(&mut self, data: &[u8]) { - match Message::decode(data) { - Ok((id, message)) => { - log::trace!("Response #{id}: {message:?}"); - if let Some(tx) = self.onflight.remove(&id) { - send_message(tx, Ok(message)); - } else { - log::error!("SFTP Error: Received a reply with an invalid id"); +impl Receiver { + /// Run a receiver until the ssh channel is closed or no more commands can be sent + pub(super) async fn run(mut self) { + log::debug!("Start SFTP client"); + + // Read all the events + while let Some(event) = self.next().await { + match event { + // New request was received + StreamItem::Request(Request(message, tx)) => { + self.next_id += 1; + let id = self.next_id; + + log::trace!("Request #{id}: {message:?}"); + + match write_msg(&mut self.stream, message, id).await { + Ok(()) => { + self.onflight.insert(id, tx); + } + Err(err) => { + log::debug!("Could not send request #{id}: {err:?}"); + send_response(tx, Err(err)); + } + } } - } - Err(err) => { - log::trace!("Failed to parse message: {data:?}"); - if let Some(mut buf) = data.get(5..9) { - let id = buf.get_u32(); - if let Some(tx) = self.onflight.remove(&id) { - send_message(tx, Err(err.into())); - } else { - log::error!("SFTP Error: Received a reply with an invalid id"); + + // New response was received + StreamItem::Response(response) => match Message::decode_raw(response.as_ref()) { + Ok((id, message)) => { + log::trace!("Response #{id}: {message:?}"); + if let Some(tx) = self.onflight.remove(&id) { + send_response(tx, Ok(message)); + } else { + log::error!("SFTP Error: Received a reply with an invalid id"); + } + } + Err(err) => { + log::trace!("Failed to parse message: {response:?}: {err:?}"); + if let Some(id) = err.id { + if let Some(tx) = self.onflight.remove(&id) { + send_response(tx, Err(err.into())); + } else { + log::error!("SFTP Error: Received a reply with an invalid id"); + } + } else { + log::error!("SFTP Error: Received a bad reply"); + } + } + }, + + // Error while receiving + StreamItem::Error(err) => { + log::error!("Error while waiting for SFTP response: {err:?}"); + match err.kind() { + std::io::ErrorKind::WouldBlock => (), + std::io::ErrorKind::TimedOut => (), + std::io::ErrorKind::WriteZero => (), + std::io::ErrorKind::Interrupted => (), + std::io::ErrorKind::OutOfMemory => (), + _ => break, } - } else { - log::error!("SFTP Error: Received a bad reply"); } } } + + for (_, tx) in self.onflight { + send_response( + tx, + Err(Error::Sftp(StatusCode::ConnectionLost.to_status( + "Could not receive response: SFTP stream stopped", + ))), + ); + } + + self.commands.close(); + if let Err(err) = self.stream.shutdown().await { + log::warn!("Error while closing SSH channel: {err:?}"); + } + + log::debug!("SFTP client stopped"); } } -fn send_message(tx: oneshot::Sender, msg: Response) { +fn send_response(tx: oneshot::Sender, msg: Response) { match tx.send(msg) { Ok(()) => (), Err(err) => { @@ -154,3 +233,23 @@ fn send_message(tx: oneshot::Sender, msg: Response) { } } } + +pub(super) async fn write_msg( + stream: &mut (impl AsyncWrite + Unpin), + msg: Message, + id: u32, +) -> Result<(), Error> { + let frame = msg.encode(id)?; + Ok(stream.write_all(frame.as_ref()).await?) +} + +pub(super) async fn read_msg( + stream: &mut (impl AsyncRead + Unpin), +) -> Result<(u32, Message), Error> { + let length = stream.read_u32().await?; + + let mut bytes = vec![0u8; length as usize]; + stream.read_exact(bytes.as_mut_slice()).await?; + + Ok(Message::decode_raw(bytes.as_slice())?) +} diff --git a/src/client/request.rs b/src/client/request.rs index c9b2405..3a10901 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -127,6 +127,7 @@ impl SftpClient { Ok(Message::Status(status)) => SftpFuture::Error(status.into()), Ok(msg) => { let (tx, rx) = oneshot::channel(); + log::trace!("Sending: {msg:?}"); match commands.send(super::receiver::Request(msg, tx)) { Ok(()) => SftpFuture::Pending { future: rx, diff --git a/src/client/stop.rs b/src/client/stop.rs index 964317a..1591536 100644 --- a/src/client/stop.rs +++ b/src/client/stop.rs @@ -76,6 +76,10 @@ impl<'a> SftpClientStopping<'a> { request_processor: Some(request_processor), }; } + + log::trace!("Client still running"); + } else { + log::trace!("stopped"); } // If the current client is not the last of the session, nothing to wait diff --git a/src/message/mod.rs b/src/message/mod.rs index d756b27..a0866c6 100644 --- a/src/message/mod.rs +++ b/src/message/mod.rs @@ -25,6 +25,7 @@ use std::borrow::Cow; use bytes::{Buf, BufMut, Bytes}; use serde::{ser::SerializeTuple, Deserialize, Serialize}; +use thiserror::Error; use crate::wire::{Error, SftpDecoder, SftpEncoder}; @@ -121,6 +122,16 @@ macro_rules! messages { } } + impl TryFrom for MessageKind { + type Error = u8; + fn try_from(value: u8) -> Result { + match value { + $($discriminant => Ok(MessageKind::$name),)* + value => Err(value), + } + } + } + impl From for MessageKind { fn from(value: Message) -> Self { value.kind() @@ -341,15 +352,51 @@ impl Message { Ok(encoder.buf.into()) } - pub fn decode(mut buf: &[u8]) -> Result<(u32, Self), Error> { + pub fn decode(mut buf: &[u8]) -> Result<(u32, Self), DecodeError> { let frame_length = buf.get_u32() as usize; // Limit the read to this very frame - let mut decoder = SftpDecoder::new(&buf[0..frame_length]); + Message::decode_raw(&buf[0..frame_length]) + } - let message_with_id = MessageWithId::deserialize(&mut decoder).map_err(Into::into)?; + pub fn decode_raw(mut buf: &[u8]) -> Result<(u32, Self), DecodeError> { + let mut decoder = SftpDecoder::new(buf); + + match MessageWithId::deserialize(&mut decoder) { + Ok(message_with_id) => Ok((message_with_id.id, message_with_id.message.into_owned())), + Err(err) => { + let mut err = DecodeError::from(err); + + if buf.remaining() >= std::mem::size_of::() { + if let Ok(kind) = MessageKind::try_from(buf.get_u8()) { + err.kind = Some(kind); + } + } + if buf.remaining() >= std::mem::size_of::() { + err.id = Some(buf.get_u32()); + } - Ok((message_with_id.id, message_with_id.message.into_owned())) + Err(err) + } + } + } +} + +#[derive(Debug, Error)] +#[error("{inner:?} (id: {id:?}, kind: {kind:?})")] +pub struct DecodeError { + pub inner: crate::wire::Error, + pub id: Option, + pub kind: Option, +} + +impl From for DecodeError { + fn from(value: crate::wire::Error) -> Self { + Self { + inner: value, + id: None, + kind: None, + } } }