diff --git a/Cargo.lock b/Cargo.lock index 0a9a9e7..dacf6ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -52,6 +52,28 @@ dependencies = [ "subtle", ] +[[package]] +name = "async-stream" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd56dd203fef61ac097dd65721a419ddccb106b2d2b70ba60a6b529f03961a51" +dependencies = [ + "async-stream-impl", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-stream-impl" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "async-trait" version = "0.1.79" @@ -987,10 +1009,13 @@ dependencies = [ name = "rusftp" version = "0.1.0" dependencies = [ + "async-trait", "bytes", "russh", + "russh-keys", "serde", "tokio", + "tokio-test", ] [[package]] @@ -1279,6 +1304,19 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-test" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2468baabc3311435b55dd935f702f42cd1b8abb7e754fb7dfb16bd36aa88f9f7" +dependencies = [ + "async-stream", + "bytes", + "futures-core", + "tokio", + "tokio-stream", +] + [[package]] name = "tokio-util" version = "0.7.10" diff --git a/Cargo.toml b/Cargo.toml index 0f0b8ac..0a6a1e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,12 @@ include = ["**/*.rs", "Cargo.toml", "LICENSE"] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +async-trait = "0.1" bytes = { version = "1.6", features = ["serde"] } -tokio = "1.36" russh = "0.43" serde = "1.0" +tokio = "1.36" + +[dev-dependencies] +russh-keys = "0.43" +tokio-test = "0.4" diff --git a/examples/simple_client.rs b/examples/simple_client.rs new file mode 100644 index 0000000..246923c --- /dev/null +++ b/examples/simple_client.rs @@ -0,0 +1,34 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use rusftp::RealPath; + +struct Handler; + +#[async_trait] +impl russh::client::Handler for Handler { + type Error = russh::Error; + async fn check_server_key( + &mut self, + _server_public_key: &russh_keys::key::PublicKey, + ) -> Result { + Ok(true) + } +} + +#[tokio::main] +pub async fn main() { + let config = Arc::new(russh::client::Config::default()); + let mut ssh = russh::client::connect(config, ("localhost", 2222), Handler) + .await + .unwrap(); + + ssh.authenticate_password("user", "pass").await.unwrap(); + let sftp = rusftp::SftpClient::new(ssh).await.unwrap(); + + let cwd = sftp.stat(rusftp::Stat { path: ".".into() }).await.unwrap(); + println!("CWD: {:?}", cwd); + + let realpath = sftp.realpath(RealPath { path: ".".into() }).await.unwrap(); + println!("RealPath: {:?}", realpath); +} diff --git a/src/client.rs b/src/client.rs index 638341c..2109bbe 100644 --- a/src/client.rs +++ b/src/client.rs @@ -15,15 +15,42 @@ // limitations under the License. use std::collections::HashMap; +use std::future::Future; +use async_trait::async_trait; +use bytes::Buf; use russh::client::Msg; use russh::Channel; use russh::ChannelMsg; -use std::future::Future; use tokio::sync::{mpsc, oneshot}; +use crate::StatusCode; use crate::{message, Message}; +/// SFTP client +/// +/// ```no_run +/// # use std::sync::Arc; +/// # use async_trait::async_trait; +/// struct Handler; +/// +/// #[async_trait] +/// impl russh::client::Handler for Handler { +/// type Error = russh::Error; +/// // ... +/// } +/// +/// # async fn dummy() -> Result<(), Box> { +/// let config = Arc::new(russh::client::Config::default()); +/// let mut ssh = russh::client::connect(config, ("localhost", 2222), Handler).await.unwrap(); +/// ssh.authenticate_password("user", "pass").await.unwrap(); +/// +/// let sftp = rusftp::SftpClient::new(&ssh).await.unwrap(); +/// let stat = sftp.stat(rusftp::Stat{path: ".".into()}).await.unwrap(); +/// println!("stat '.': {stat:?}"); +/// # Ok(()) +/// # } +/// ``` pub struct SftpClient { commands: mpsc::UnboundedSender<(Message, oneshot::Sender)>, } @@ -59,7 +86,10 @@ macro_rules! command { } impl SftpClient { - pub async fn new(mut channel: Channel) -> Result { + pub async fn new(ssh: T) -> Result { + Self::with_channel(ssh.to_sftp_channel().await?).await + } + pub async fn with_channel(mut channel: Channel) -> Result { // Start SFTP subsystem match channel.request_subsystem(false, "sftp").await { Ok(_) => (), @@ -88,41 +118,46 @@ impl SftpClient { })?; // Check handshake response - match channel.wait().await { - Some(ChannelMsg::Data { data }) => { - match Message::decode(data.as_ref()) { - // Valid response: continue - Ok(( - _, - Message::Version(message::Version { - version: 3, - extensions: _, - }), - )) => (), - - // Invalid responses: abort - Ok((_, Message::Version(_))) => { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "Invalid sftp version", - )); - } - Ok(_) => { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "Bad SFTP init", - )); - } - Err(err) => { - return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); + loop { + match channel.wait().await { + Some(ChannelMsg::Data { data }) => { + match Message::decode(data.as_ref()) { + // Valid response: continue + Ok(( + _, + Message::Version(message::Version { + version: 3, + extensions: _, + }), + )) => break, + + // Invalid responses: abort + Ok((_, Message::Version(_))) => { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Invalid sftp version", + )); + } + Ok(_) => { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Bad SFTP init", + )); + } + Err(err) => { + return Err(std::io::Error::new(std::io::ErrorKind::Other, err)); + } } } - } - _ => { - return Err(std::io::Error::new( - std::io::ErrorKind::InvalidData, - "Failed to start SFTP subsystem", - )); + // Unrelated event has been received, looping is required + Some(_) => (), + // Channel has been closed + None => { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "Failed to start SFTP subsystem", + )); + } } } @@ -174,7 +209,21 @@ impl SftpClient { } } Err(err) => { - eprintln!("SFTP Error: Could not decode server frame: {err}"); + if let Some(mut buf) = data.as_ref().get(5..9){ + + let id = buf.get_u32(); + if let Some(tx) = onflight.remove(&id) { + _ = tx.send(Message::Status(crate::Status { + code: StatusCode::BadMessage as u32, + error: err.to_string().into(), + language: "en".into(), + })); + } else { + eprintln!("SFTP Error: Received a reply with an invalid id"); + } + } else { + eprintln!("SFTP Error: Received a bad reply"); + } } } }, @@ -226,3 +275,33 @@ impl SftpClient { command!(readlink: ReadLink -> Name); command!(symlink: Symlink); } + +#[async_trait] +pub trait ToSftpChannel { + async fn to_sftp_channel(self) -> Result, std::io::Error>; +} + +#[async_trait] +impl ToSftpChannel for Channel { + async fn to_sftp_channel(self) -> Result, std::io::Error> { + Ok(self) + } +} + +#[async_trait] +impl ToSftpChannel for &russh::client::Handle { + async fn to_sftp_channel(self) -> Result, std::io::Error> { + match self.channel_open_session().await { + Ok(channel) => Ok(channel), + Err(russh::Error::IO(err)) => Err(err), + Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)), + } + } +} + +#[async_trait] +impl ToSftpChannel for russh::client::Handle { + async fn to_sftp_channel(self) -> Result, std::io::Error> { + (&self).to_sftp_channel().await + } +} diff --git a/src/message/handle.rs b/src/message/handle.rs index 67a001c..159ab72 100644 --- a/src/message/handle.rs +++ b/src/message/handle.rs @@ -14,12 +14,34 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::ops::Deref; + use bytes::Bytes; use serde::{Deserialize, Serialize}; #[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] pub struct Handle(pub Bytes); +impl> From for Handle { + fn from(value: T) -> Self { + Self(value.into()) + } +} + +impl Deref for Handle { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.as_ref() + } +} + +impl AsRef<[u8]> for Handle { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + #[cfg(test)] mod test { use crate::message::test_utils::{encode_decode, fail_decode, BYTES_INVALID, BYTES_VALID}; diff --git a/src/message/mod.rs b/src/message/mod.rs index 822c406..6ab0a99 100644 --- a/src/message/mod.rs +++ b/src/message/mod.rs @@ -66,7 +66,7 @@ pub use handle::Handle; pub use init::Init; pub use lstat::LStat; pub use mkdir::MkDir; -pub use name::Name; +pub use name::{Name, NameEntry}; pub use open::{pflags, Open}; pub use opendir::OpenDir; pub use path::Path; diff --git a/src/message/name.rs b/src/message/name.rs index 8fbf2a0..14f3482 100644 --- a/src/message/name.rs +++ b/src/message/name.rs @@ -14,17 +14,111 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::{ + borrow::{Borrow, BorrowMut}, + ops::{Deref, DerefMut, Index, IndexMut}, + slice::SliceIndex, +}; + use serde::{Deserialize, Serialize}; use super::{Attrs, Path}; #[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] -pub struct Name { +pub struct NameEntry { pub filename: Path, pub long_name: Path, pub attrs: Attrs, } +#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)] +pub struct Name(pub Vec); + +impl IntoIterator for Name { + type Item = NameEntry; + + type IntoIter = as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +impl<'a> IntoIterator for &'a Name { + type Item = &'a NameEntry; + + type IntoIter = <&'a Vec as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter() + } +} + +impl<'a> IntoIterator for &'a mut Name { + type Item = &'a mut NameEntry; + + type IntoIter = <&'a mut Vec as IntoIterator>::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.iter_mut() + } +} + +impl FromIterator for Name { + fn from_iter>(iter: T) -> Self { + Self(Vec::from_iter(iter)) + } +} + +impl> Index for Name { + type Output = I::Output; + + fn index(&self, index: I) -> &Self::Output { + &self.0[index] + } +} + +impl> IndexMut for Name { + fn index_mut(&mut self, index: I) -> &mut Self::Output { + &mut self.0[index] + } +} + +impl AsRef<[NameEntry]> for Name { + fn as_ref(&self) -> &[NameEntry] { + &self.0 + } +} +impl AsMut<[NameEntry]> for Name { + fn as_mut(&mut self) -> &mut [NameEntry] { + &mut self.0 + } +} + +impl Deref for Name { + type Target = [NameEntry]; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} +impl DerefMut for Name { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + +impl Borrow<[NameEntry]> for Name { + fn borrow(&self) -> &[NameEntry] { + &self.0 + } +} +impl BorrowMut<[NameEntry]> for Name { + fn borrow_mut(&mut self) -> &mut [NameEntry] { + &mut self.0 + } +} + #[cfg(test)] mod test { use crate::{ @@ -32,7 +126,7 @@ mod test { Attrs, Error, Path, }; - use super::Name; + use super::NameEntry; use bytes::Bytes; const NAME_VALID: &[u8] = @@ -41,7 +135,7 @@ mod test { #[test] fn encode_success() { encode_decode( - Name { + NameEntry { filename: Path(Bytes::from_static(b"filename")), long_name: Path(Bytes::from_static(b"long name")), attrs: Attrs { @@ -56,7 +150,10 @@ mod test { #[test] fn decode_failure() { for i in 0..NAME_VALID.len() { - assert_eq!(fail_decode::(&NAME_VALID[..i]), Error::NotEnoughData); + assert_eq!( + fail_decode::(&NAME_VALID[..i]), + Error::NotEnoughData + ); } } } diff --git a/src/message/path.rs b/src/message/path.rs index d8a9029..dc1ad9e 100644 --- a/src/message/path.rs +++ b/src/message/path.rs @@ -14,6 +14,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::ops::Deref; + use bytes::Bytes; use serde::{Deserialize, Serialize}; @@ -26,6 +28,20 @@ impl> From for Path { } } +impl Deref for Path { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + self.as_ref() + } +} + +impl AsRef<[u8]> for Path { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + #[cfg(test)] mod test { use crate::message::test_utils::{encode_decode, fail_decode, BYTES_INVALID, BYTES_VALID};