Skip to content

Commit

Permalink
test: Add client example (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
lemaitre-aneo authored Mar 30, 2024
2 parents 48e667c + ae0194d commit de94731
Show file tree
Hide file tree
Showing 8 changed files with 333 additions and 42 deletions.
38 changes: 38 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ include = ["**/*.rs", "Cargo.toml", "LICENSE"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1"
bytes = { version = "1.6", features = ["serde"] }
tokio = "1.36"
russh = "0.43"
serde = "1.0"
tokio = "1.36"

[dev-dependencies]
russh-keys = "0.43"
tokio-test = "0.4"
34 changes: 34 additions & 0 deletions examples/simple_client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::sync::Arc;

use async_trait::async_trait;
use rusftp::RealPath;

struct Handler;

#[async_trait]
impl russh::client::Handler for Handler {
type Error = russh::Error;
async fn check_server_key(
&mut self,
_server_public_key: &russh_keys::key::PublicKey,
) -> Result<bool, Self::Error> {
Ok(true)
}
}

#[tokio::main]
pub async fn main() {
let config = Arc::new(russh::client::Config::default());
let mut ssh = russh::client::connect(config, ("localhost", 2222), Handler)
.await
.unwrap();

ssh.authenticate_password("user", "pass").await.unwrap();
let sftp = rusftp::SftpClient::new(ssh).await.unwrap();

let cwd = sftp.stat(rusftp::Stat { path: ".".into() }).await.unwrap();
println!("CWD: {:?}", cwd);

let realpath = sftp.realpath(RealPath { path: ".".into() }).await.unwrap();
println!("RealPath: {:?}", realpath);
}
151 changes: 115 additions & 36 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,42 @@
// limitations under the License.

use std::collections::HashMap;
use std::future::Future;

use async_trait::async_trait;
use bytes::Buf;
use russh::client::Msg;
use russh::Channel;
use russh::ChannelMsg;
use std::future::Future;
use tokio::sync::{mpsc, oneshot};

use crate::StatusCode;
use crate::{message, Message};

/// SFTP client
///
/// ```no_run
/// # use std::sync::Arc;
/// # use async_trait::async_trait;
/// struct Handler;
///
/// #[async_trait]
/// impl russh::client::Handler for Handler {
/// type Error = russh::Error;
/// // ...
/// }
///
/// # async fn dummy() -> Result<(), Box<dyn std::error::Error>> {
/// let config = Arc::new(russh::client::Config::default());
/// let mut ssh = russh::client::connect(config, ("localhost", 2222), Handler).await.unwrap();
/// ssh.authenticate_password("user", "pass").await.unwrap();
///
/// let sftp = rusftp::SftpClient::new(&ssh).await.unwrap();
/// let stat = sftp.stat(rusftp::Stat{path: ".".into()}).await.unwrap();
/// println!("stat '.': {stat:?}");
/// # Ok(())
/// # }
/// ```
pub struct SftpClient {
commands: mpsc::UnboundedSender<(Message, oneshot::Sender<Message>)>,
}
Expand Down Expand Up @@ -59,7 +86,10 @@ macro_rules! command {
}

impl SftpClient {
pub async fn new(mut channel: Channel<Msg>) -> Result<Self, std::io::Error> {
pub async fn new<T: ToSftpChannel>(ssh: T) -> Result<Self, std::io::Error> {
Self::with_channel(ssh.to_sftp_channel().await?).await
}
pub async fn with_channel(mut channel: Channel<Msg>) -> Result<Self, std::io::Error> {
// Start SFTP subsystem
match channel.request_subsystem(false, "sftp").await {
Ok(_) => (),
Expand Down Expand Up @@ -88,41 +118,46 @@ impl SftpClient {
})?;

// Check handshake response
match channel.wait().await {
Some(ChannelMsg::Data { data }) => {
match Message::decode(data.as_ref()) {
// Valid response: continue
Ok((
_,
Message::Version(message::Version {
version: 3,
extensions: _,
}),
)) => (),

// Invalid responses: abort
Ok((_, Message::Version(_))) => {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Invalid sftp version",
));
}
Ok(_) => {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Bad SFTP init",
));
}
Err(err) => {
return Err(std::io::Error::new(std::io::ErrorKind::Other, err));
loop {
match channel.wait().await {
Some(ChannelMsg::Data { data }) => {
match Message::decode(data.as_ref()) {
// Valid response: continue
Ok((
_,
Message::Version(message::Version {
version: 3,
extensions: _,
}),
)) => break,

// Invalid responses: abort
Ok((_, Message::Version(_))) => {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Invalid sftp version",
));
}
Ok(_) => {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Bad SFTP init",
));
}
Err(err) => {
return Err(std::io::Error::new(std::io::ErrorKind::Other, err));
}
}
}
}
_ => {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Failed to start SFTP subsystem",
));
// Unrelated event has been received, looping is required
Some(_) => (),
// Channel has been closed
None => {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Failed to start SFTP subsystem",
));
}
}
}

Expand Down Expand Up @@ -174,7 +209,21 @@ impl SftpClient {
}
}
Err(err) => {
eprintln!("SFTP Error: Could not decode server frame: {err}");
if let Some(mut buf) = data.as_ref().get(5..9){

let id = buf.get_u32();
if let Some(tx) = onflight.remove(&id) {
_ = tx.send(Message::Status(crate::Status {
code: StatusCode::BadMessage as u32,
error: err.to_string().into(),
language: "en".into(),
}));
} else {
eprintln!("SFTP Error: Received a reply with an invalid id");
}
} else {
eprintln!("SFTP Error: Received a bad reply");
}
}
}
},
Expand Down Expand Up @@ -226,3 +275,33 @@ impl SftpClient {
command!(readlink: ReadLink -> Name);
command!(symlink: Symlink);
}

#[async_trait]
pub trait ToSftpChannel {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, std::io::Error>;
}

#[async_trait]
impl ToSftpChannel for Channel<Msg> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, std::io::Error> {
Ok(self)
}
}

#[async_trait]
impl<H: russh::client::Handler> ToSftpChannel for &russh::client::Handle<H> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, std::io::Error> {
match self.channel_open_session().await {
Ok(channel) => Ok(channel),
Err(russh::Error::IO(err)) => Err(err),
Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)),
}
}
}

#[async_trait]
impl<H: russh::client::Handler> ToSftpChannel for russh::client::Handle<H> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, std::io::Error> {
(&self).to_sftp_channel().await
}
}
22 changes: 22 additions & 0 deletions src/message/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,34 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::ops::Deref;

use bytes::Bytes;
use serde::{Deserialize, Serialize};

#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
pub struct Handle(pub Bytes);

impl<T: Into<Bytes>> From<T> for Handle {
fn from(value: T) -> Self {
Self(value.into())
}
}

impl Deref for Handle {
type Target = [u8];

fn deref(&self) -> &Self::Target {
self.as_ref()
}
}

impl AsRef<[u8]> for Handle {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}

#[cfg(test)]
mod test {
use crate::message::test_utils::{encode_decode, fail_decode, BYTES_INVALID, BYTES_VALID};
Expand Down
2 changes: 1 addition & 1 deletion src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub use handle::Handle;
pub use init::Init;
pub use lstat::LStat;
pub use mkdir::MkDir;
pub use name::Name;
pub use name::{Name, NameEntry};
pub use open::{pflags, Open};
pub use opendir::OpenDir;
pub use path::Path;
Expand Down
Loading

0 comments on commit de94731

Please sign in to comment.