From ede7db2ff5817efb1f325ab5fff4eb618f154e6a Mon Sep 17 00:00:00 2001 From: Florian Lemaitre Date: Sat, 30 Mar 2024 18:09:53 +0100 Subject: [PATCH 1/4] Add client example --- Cargo.lock | 38 +++++++++++ Cargo.toml | 7 +- examples/simple_client.rs | 30 +++++++++ src/client.rs | 134 ++++++++++++++++++++++++++++---------- src/message/handle.rs | 22 +++++++ src/message/path.rs | 16 +++++ 6 files changed, 212 insertions(+), 35 deletions(-) create mode 100644 examples/simple_client.rs diff --git a/Cargo.lock b/Cargo.lock index 0a9a9e7..dacf6ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,6 +52,28 @@ dependencies = [ "subtle", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.79" @@ -987,10 +1009,13 @@ dependencies = [ name = "rusftp" version = "0.1.0" dependencies = [ + "async-trait", "bytes", "russh", + "russh-keys", "serde", "tokio", + "tokio-test", ] [[package]] @@ -1279,6 +1304,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.10" diff --git a/Cargo.toml b/Cargo.toml index 0f0b8ac..0a6a1e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,12 @@ include = ["**/*.rs", "Cargo.toml", "LICENSE"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +async-trait = "0.1" bytes = { version = "1.6", features = ["serde"] } -tokio = "1.36" russh = "0.43" serde = "1.0" +tokio = "1.36" + +[dev-dependencies] +russh-keys = "0.43" +tokio-test = "0.4" diff --git a/examples/simple_client.rs b/examples/simple_client.rs new file mode 100644 index 0000000..2fa4558 --- /dev/null +++ b/examples/simple_client.rs @@ -0,0 +1,30 @@ +use std::sync::Arc; + +use async_trait::async_trait; + +struct Handler; + +#[async_trait] +impl russh::client::Handler for Handler { + type Error = russh::Error; + async fn check_server_key( + &mut self, + _server_public_key: &russh_keys::key::PublicKey, + ) -> Result { + Ok(true) + } +} + +#[tokio::main] +pub async fn main() { + let config = Arc::new(russh::client::Config::default()); + let mut ssh = russh::client::connect(config, ("localhost", 2222), Handler) + .await + .unwrap(); + + ssh.authenticate_password("user", "pass").await.unwrap(); + let sftp = rusftp::SftpClient::new(ssh).await.unwrap(); + + let cwd = sftp.stat(rusftp::Stat { path: ".".into() }).await.unwrap(); + println!("CWD: {:?}", cwd); +} diff --git a/src/client.rs b/src/client.rs index 638341c..0069c75 100644 --- a/src/client.rs +++ b/src/client.rs @@ -16,6 +16,7 @@ use std::collections::HashMap; +use async_trait::async_trait; use russh::client::Msg; use russh::Channel; use russh::ChannelMsg; @@ -24,6 +25,33 @@ use tokio::sync::{mpsc, oneshot}; use crate::{message, Message}; +/// SFTP client +/// +/// ``` +/// # use std::sync::Arc; +/// # use async_trait::async_trait; +/// # tokio_test::block_on(async { +/// struct Handler; +/// +/// #[async_trait] +/// impl russh::client::Handler for Handler { +/// type Error = russh::Error; +/// async fn check_server_key( +/// &mut self, +/// server_public_key: &russh_keys::key::PublicKey, +/// ) -> Result { +/// Ok(true) +/// } +/// } +/// let config = Arc::new(russh::client::Config::default()); +/// let mut ssh = russh::client::connect(config, ("localhost", 2222), Handler).await.unwrap(); +/// ssh.authenticate_password("user", "pass").await.unwrap(); +/// +/// let sftp = rusftp::SftpClient::new(&ssh).await.unwrap(); +/// let stat = sftp.stat(rusftp::Stat{path: ".".into()}).await.unwrap(); +/// println!("stat '.': {stat:?}"); +/// # }) +/// ``` pub struct SftpClient { commands: mpsc::UnboundedSender<(Message, oneshot::Sender)>, } @@ -59,7 +87,10 @@ macro_rules! command { } impl SftpClient { - pub async fn new(mut channel: Channel) -> Result { + pub async fn new(ssh: T) -> Result { + Self::with_channel(ssh.to_sftp_channel().await?).await + } + pub async fn with_channel(mut channel: Channel) -> Result { // Start SFTP subsystem match channel.request_subsystem(false, "sftp").await { Ok(_) => (), @@ -88,41 +119,46 @@ impl SftpClient { })?; // Check handshake response - match channel.wait().await { - Some(ChannelMsg::Data { data }) => { - match Message::decode(data.as_ref()) { - // Valid response: continue - Ok(( - _, - Message::Version(message::Version { - version: 3, - extensions: _, - }), - )) => (), - - // Invalid responses: abort - Ok((_, Message::Version(_))) => { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "Invalid sftp version", - )); - } - Ok(_) => { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "Bad SFTP init", - )); - } - Err(err) => { - return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); + loop { + match channel.wait().await { + Some(ChannelMsg::Data { data }) => { + match Message::decode(data.as_ref()) { + // Valid response: continue + Ok(( + _, + Message::Version(message::Version { + version: 3, + extensions: _, + }), + )) => break, + + // Invalid responses: abort + Ok((_, Message::Version(_))) => { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Invalid sftp version", + )); + } + Ok(_) => { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Bad SFTP init", + )); + } + Err(err) => { + return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); + } } } - } - _ => { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "Failed to start SFTP subsystem", - )); + // Unrelated event has been received, looping is required + Some(_) => (), + // Channel has been closed + None => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Failed to start SFTP subsystem", + )); + } } } @@ -226,3 +262,33 @@ impl SftpClient { command!(readlink: ReadLink -> Name); command!(symlink: Symlink); } + +#[async_trait] +pub trait ToSftpChannel { + async fn to_sftp_channel(self) -> Result, std::io::Error>; +} + +#[async_trait] +impl ToSftpChannel for Channel { + async fn to_sftp_channel(self) -> Result, std::io::Error> { + Ok(self) + } +} + +#[async_trait] +impl ToSftpChannel for &russh::client::Handle { + async fn to_sftp_channel(self) -> Result, std::io::Error> { + match self.channel_open_session().await { + Ok(channel) => Ok(channel), + Err(russh::Error::IO(err)) => Err(err), + Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)), + } + } +} + +#[async_trait] +impl ToSftpChannel for russh::client::Handle { + async fn to_sftp_channel(self) -> Result, std::io::Error> { + (&self).to_sftp_channel().await + } +} diff --git a/src/message/handle.rs b/src/message/handle.rs index 67a001c..159ab72 100644 --- a/src/message/handle.rs +++ b/src/message/handle.rs @@ -14,12 +14,34 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::ops::Deref; + use bytes::Bytes; use serde::{Deserialize, Serialize}; #[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] pub struct Handle(pub Bytes); +impl> From for Handle { + fn from(value: T) -> Self { + Self(value.into()) + } +} + +impl Deref for Handle { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.as_ref() + } +} + +impl AsRef<[u8]> for Handle { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + #[cfg(test)] mod test { use crate::message::test_utils::{encode_decode, fail_decode, BYTES_INVALID, BYTES_VALID}; diff --git a/src/message/path.rs b/src/message/path.rs index d8a9029..dc1ad9e 100644 --- a/src/message/path.rs +++ b/src/message/path.rs @@ -14,6 +14,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::ops::Deref; + use bytes::Bytes; use serde::{Deserialize, Serialize}; @@ -26,6 +28,20 @@ impl> From for Path { } } +impl Deref for Path { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.as_ref() + } +} + +impl AsRef<[u8]> for Path { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + #[cfg(test)] mod test { use crate::message::test_utils::{encode_decode, fail_decode, BYTES_INVALID, BYTES_VALID}; From a01ea036941e59fad557cf24a0015423bd3526bd Mon Sep 17 00:00:00 2001 From: Florian Lemaitre Date: Sat, 30 Mar 2024 18:29:04 +0100 Subject: [PATCH 2/4] Catch bad response --- examples/simple_client.rs | 3 +++ src/client.rs | 20 ++++++++++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/examples/simple_client.rs b/examples/simple_client.rs index 2fa4558..baee742 100644 --- a/examples/simple_client.rs +++ b/examples/simple_client.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use async_trait::async_trait; +use rusftp::RealPath; struct Handler; @@ -27,4 +28,6 @@ pub async fn main() { let cwd = sftp.stat(rusftp::Stat { path: ".".into() }).await.unwrap(); println!("CWD: {:?}", cwd); + + sftp.realpath(RealPath { path: ".".into() }).await.unwrap(); } diff --git a/src/client.rs b/src/client.rs index 0069c75..f5467c2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -15,14 +15,16 @@ // limitations under the License. use std::collections::HashMap; +use std::future::Future; use async_trait::async_trait; +use bytes::Buf; use russh::client::Msg; use russh::Channel; use russh::ChannelMsg; -use std::future::Future; use tokio::sync::{mpsc, oneshot}; +use crate::StatusCode; use crate::{message, Message}; /// SFTP client @@ -210,7 +212,21 @@ impl SftpClient { } } Err(err) => { - eprintln!("SFTP Error: Could not decode server frame: {err}"); + if let Some(mut buf) = data.as_ref().get(5..9){ + + let id = buf.get_u32(); + if let Some(tx) = onflight.remove(&id) { + _ = tx.send(Message::Status(crate::Status { + code: StatusCode::BadMessage as u32, + error: err.to_string().into(), + language: "en".into(), + })); + } else { + eprintln!("SFTP Error: Received a reply with an invalid id"); + } + } else { + eprintln!("SFTP Error: Received a bad reply"); + } } } }, From fad9840c3d74ef45138c0ef657cbca8ea4f83390 Mon Sep 17 00:00:00 2001 From: Florian Lemaitre Date: Sat, 30 Mar 2024 19:02:02 +0100 Subject: [PATCH 3/4] fix name --- examples/simple_client.rs | 3 +- src/message/mod.rs | 2 +- src/message/name.rs | 105 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 104 insertions(+), 6 deletions(-) diff --git a/examples/simple_client.rs b/examples/simple_client.rs index baee742..246923c 100644 --- a/examples/simple_client.rs +++ b/examples/simple_client.rs @@ -29,5 +29,6 @@ pub async fn main() { let cwd = sftp.stat(rusftp::Stat { path: ".".into() }).await.unwrap(); println!("CWD: {:?}", cwd); - sftp.realpath(RealPath { path: ".".into() }).await.unwrap(); + let realpath = sftp.realpath(RealPath { path: ".".into() }).await.unwrap(); + println!("RealPath: {:?}", realpath); } diff --git a/src/message/mod.rs b/src/message/mod.rs index 822c406..6ab0a99 100644 --- a/src/message/mod.rs +++ b/src/message/mod.rs @@ -66,7 +66,7 @@ pub use handle::Handle; pub use init::Init; pub use lstat::LStat; pub use mkdir::MkDir; -pub use name::Name; +pub use name::{Name, NameEntry}; pub use open::{pflags, Open}; pub use opendir::OpenDir; pub use path::Path; diff --git a/src/message/name.rs b/src/message/name.rs index 8fbf2a0..14f3482 100644 --- a/src/message/name.rs +++ b/src/message/name.rs @@ -14,17 +14,111 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::{ + borrow::{Borrow, BorrowMut}, + ops::{Deref, DerefMut, Index, IndexMut}, + slice::SliceIndex, +}; + use serde::{Deserialize, Serialize}; use super::{Attrs, Path}; #[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] -pub struct Name { +pub struct NameEntry { pub filename: Path, pub long_name: Path, pub attrs: Attrs, } +#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] +pub struct Name(pub Vec); + +impl IntoIterator for Name { + type Item = NameEntry; + + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a> IntoIterator for &'a Name { + type Item = &'a NameEntry; + + type IntoIter = <&'a Vec as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +impl<'a> IntoIterator for &'a mut Name { + type Item = &'a mut NameEntry; + + type IntoIter = <&'a mut Vec as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter_mut() + } +} + +impl FromIterator for Name { + fn from_iter>(iter: T) -> Self { + Self(Vec::from_iter(iter)) + } +} + +impl> Index for Name { + type Output = I::Output; + + fn index(&self, index: I) -> &Self::Output { + &self.0[index] + } +} + +impl> IndexMut for Name { + fn index_mut(&mut self, index: I) -> &mut Self::Output { + &mut self.0[index] + } +} + +impl AsRef<[NameEntry]> for Name { + fn as_ref(&self) -> &[NameEntry] { + &self.0 + } +} +impl AsMut<[NameEntry]> for Name { + fn as_mut(&mut self) -> &mut [NameEntry] { + &mut self.0 + } +} + +impl Deref for Name { + type Target = [NameEntry]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for Name { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Borrow<[NameEntry]> for Name { + fn borrow(&self) -> &[NameEntry] { + &self.0 + } +} +impl BorrowMut<[NameEntry]> for Name { + fn borrow_mut(&mut self) -> &mut [NameEntry] { + &mut self.0 + } +} + #[cfg(test)] mod test { use crate::{ @@ -32,7 +126,7 @@ mod test { Attrs, Error, Path, }; - use super::Name; + use super::NameEntry; use bytes::Bytes; const NAME_VALID: &[u8] = @@ -41,7 +135,7 @@ mod test { #[test] fn encode_success() { encode_decode( - Name { + NameEntry { filename: Path(Bytes::from_static(b"filename")), long_name: Path(Bytes::from_static(b"long name")), attrs: Attrs { @@ -56,7 +150,10 @@ mod test { #[test] fn decode_failure() { for i in 0..NAME_VALID.len() { - assert_eq!(fail_decode::(&NAME_VALID[..i]), Error::NotEnoughData); + assert_eq!( + fail_decode::(&NAME_VALID[..i]), + Error::NotEnoughData + ); } } } From ae0194da54fa071f21b66f4c420b878a6084b8b7 Mon Sep 17 00:00:00 2001 From: Florian Lemaitre Date: Sat, 30 Mar 2024 19:11:15 +0100 Subject: [PATCH 4/4] Fix pipeline --- src/client.rs | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/client.rs b/src/client.rs index f5467c2..2109bbe 100644 --- a/src/client.rs +++ b/src/client.rs @@ -29,22 +29,18 @@ use crate::{message, Message}; /// SFTP client /// -/// ``` +/// ```no_run /// # use std::sync::Arc; /// # use async_trait::async_trait; -/// # tokio_test::block_on(async { /// struct Handler; /// /// #[async_trait] /// impl russh::client::Handler for Handler { /// type Error = russh::Error; -/// async fn check_server_key( -/// &mut self, -/// server_public_key: &russh_keys::key::PublicKey, -/// ) -> Result { -/// Ok(true) -/// } +/// // ... /// } +/// +/// # async fn dummy() -> Result<(), Box> { /// let config = Arc::new(russh::client::Config::default()); /// let mut ssh = russh::client::connect(config, ("localhost", 2222), Handler).await.unwrap(); /// ssh.authenticate_password("user", "pass").await.unwrap(); @@ -52,7 +48,8 @@ use crate::{message, Message}; /// let sftp = rusftp::SftpClient::new(&ssh).await.unwrap(); /// let stat = sftp.stat(rusftp::Stat{path: ".".into()}).await.unwrap(); /// println!("stat '.': {stat:?}"); -/// # }) +/// # Ok(()) +/// # } /// ``` pub struct SftpClient { commands: mpsc::UnboundedSender<(Message, oneshot::Sender)>,