Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support AsyncRead+Write for receiver #19

Merged
merged 1 commit into from
May 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/client/dir/close.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ impl<'a> DirClosing<'a> {
dir.buffer = None;
dir.pending = None;
if let Some(handle) = dir.handle.take() {
log::trace!("wait for closing");
let pending = dir.client.close(handle.clone());
return DirClosing(DirClosingState::Closing {
dir,
Expand All @@ -94,8 +95,10 @@ impl<'a> DirClosing<'a> {

let stop = SftpClientStopping::new(&mut dir.client);
if stop.is_stopped() {
log::trace!("closed and stopped");
DirClosing(DirClosingState::Closed)
} else {
log::trace!("closed, wait for stopping");
DirClosing(DirClosingState::Stopping(stop))
}
}
Expand Down
6 changes: 6 additions & 0 deletions src/client/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ impl From<russh::Error> for Error {
}
}

impl From<crate::message::DecodeError> for Error {
fn from(value: crate::message::DecodeError) -> Self {
Self::WireFormat(value.inner)
}
}

impl From<Error> for std::io::Error {
fn from(value: Error) -> Self {
match value {
Expand Down
3 changes: 3 additions & 0 deletions src/client/file/close.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ impl<'a> FileClosing<'a> {
file.pending = PendingOperation::None;
if let Some(handle) = file.handle.take() {
if let Some(handle) = Arc::into_inner(handle) {
log::trace!("wait for closing");
let pending = file.client.close(handle.clone());
return FileClosing(FileClosingState::Closing {
file,
Expand All @@ -91,8 +92,10 @@ impl<'a> FileClosing<'a> {

let stop = SftpClientStopping::new(&mut file.client);
if stop.is_stopped() {
log::trace!("closed and stopped");
FileClosing(FileClosingState::Closed)
} else {
log::trace!("closed, wait for stopping");
FileClosing(FileClosingState::Stopping(stop))
}
}
Expand Down
129 changes: 65 additions & 64 deletions src/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
use std::sync::Arc;

use async_trait::async_trait;
use russh::{client::Msg, Channel, ChannelMsg};
use tokio::sync::mpsc;
use russh::ChannelStream;
use russh::{client::Msg, Channel};
use tokio::io::AsyncWrite;
use tokio::task::JoinHandle;
use tokio::{io::AsyncRead, sync::mpsc};

use crate::message::{Init, Message, StatusCode, Version};

Expand Down Expand Up @@ -88,66 +90,50 @@ impl SftpClient {

/// Creates a new client from a ssh connection.
///
/// `ssh` can be a [`russh::Channel<Msg>`])
/// `ssh` can be a [`russh::Channel<Msg>`]
/// or a [`russh::client::Handler`].
/// In case of the handler, it can be moved or borrowed.
pub async fn new<T: ToSftpChannel>(ssh: T) -> Result<Self, Error> {
Self::with_channel(ssh.to_sftp_channel().await?).await
pub async fn new<T: IntoSftpStream>(ssh: T) -> Result<Self, Error> {
Self::with_stream(ssh.into_sftp_stream().await?).await
}

/// Creates a new client from a [`russh::Channel<Msg>`].
pub async fn with_channel(mut channel: Channel<Msg>) -> Result<Self, Error> {
// Start SFTP subsystem
channel.request_subsystem(false, "sftp").await?;

/// Creates a new client from a stream ([`AsyncRead`] + [`AsyncWrite`]).
pub async fn with_stream(
mut stream: impl AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
) -> Result<Self, Error> {
// Init SFTP handshake
let init_message = Message::Init(Init {
version: 3,
extensions: Default::default(),
});
let init_frame = init_message.encode(0)?;
channel.data(init_frame.as_ref()).await?;

// Check handshake response
loop {
match channel.wait().await {
Some(ChannelMsg::Data { data }) => {
match Message::decode(data.as_ref()) {
// Valid response: continue
Ok((
_,
Message::Version(Version {
version: 3,
extensions: _,
}),
)) => break,

// Invalid responses: abort
Ok((_, Message::Version(_))) => {
return Err(StatusCode::BadMessage
.to_status("Invalid sftp version")
.into());
}
Ok(_) => {
return Err(StatusCode::BadMessage.to_status("Bad SFTP init").into());
}
Err(err) => {
return Err(err.into());
}
}
}
// Unrelated event has been received, looping is required
Some(_) => (),
// Channel has been closed
None => {
return Err(StatusCode::BadMessage
.to_status("Failed to start SFTP subsystem")
.into());
}
receiver::write_msg(
&mut stream,
Message::Init(Init {
version: 3,
extensions: Default::default(),
}),
3,
)
.await?;

match receiver::read_msg(&mut stream).await? {
// Valid response: continue
(
_,
Message::Version(Version {
version: 3,
extensions: _,
}),
) => (),

// Invalid responses: abort
(_, Message::Version(_)) => {
return Err(StatusCode::BadMessage
.to_status("Invalid sftp version")
.into());
}
_ => {
return Err(StatusCode::BadMessage.to_status("Bad SFTP init").into());
}
}

let (receiver, tx) = receiver::Receiver::new(channel);
let (receiver, tx) = receiver::Receiver::new(stream);
let request_processor = tokio::spawn(receiver.run());

Ok(Self {
Expand All @@ -165,27 +151,42 @@ impl std::fmt::Debug for SftpClient {

/// Convert the object to a SSH channel
#[async_trait]
pub trait ToSftpChannel {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, Error>;
pub trait IntoSftpStream {
type Stream: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static;
async fn into_sftp_stream(self) -> Result<Self::Stream, Error>;
}

#[async_trait]
impl ToSftpChannel for Channel<Msg> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, Error> {
impl IntoSftpStream for ChannelStream<Msg> {
type Stream = ChannelStream<Msg>;
async fn into_sftp_stream(self) -> Result<Self::Stream, Error> {
Ok(self)
}
}

#[async_trait]
impl<H: russh::client::Handler> ToSftpChannel for &russh::client::Handle<H> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, Error> {
self.channel_open_session().await.map_err(Into::into)
impl IntoSftpStream for Channel<Msg> {
type Stream = ChannelStream<Msg>;
async fn into_sftp_stream(self) -> Result<Self::Stream, Error> {
// Start SFTP subsystem
self.request_subsystem(false, "sftp").await?;

Ok(self.into_stream())
}
}

#[async_trait]
impl<H: russh::client::Handler> IntoSftpStream for &russh::client::Handle<H> {
type Stream = ChannelStream<Msg>;
async fn into_sftp_stream(self) -> Result<Self::Stream, Error> {
self.channel_open_session().await?.into_sftp_stream().await
}
}

#[async_trait]
impl<H: russh::client::Handler> ToSftpChannel for russh::client::Handle<H> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, Error> {
(&self).to_sftp_channel().await
impl<H: russh::client::Handler> IntoSftpStream for russh::client::Handle<H> {
type Stream = ChannelStream<Msg>;
async fn into_sftp_stream(self) -> Result<Self::Stream, Error> {
(&self).into_sftp_stream().await
}
}
Loading