Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: Add client example #4

Merged
merged 4 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@ include = ["**/*.rs", "Cargo.toml", "LICENSE"]
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
async-trait = "0.1"
bytes = { version = "1.6", features = ["serde"] }
tokio = "1.36"
russh = "0.43"
serde = "1.0"
tokio = "1.36"

[dev-dependencies]
russh-keys = "0.43"
tokio-test = "0.4"
34 changes: 34 additions & 0 deletions examples/simple_client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::sync::Arc;

use async_trait::async_trait;
use rusftp::RealPath;

struct Handler;

#[async_trait]
impl russh::client::Handler for Handler {
type Error = russh::Error;
async fn check_server_key(
&mut self,
_server_public_key: &russh_keys::key::PublicKey,
) -> Result<bool, Self::Error> {
Ok(true)
}
}

#[tokio::main]
pub async fn main() {
let config = Arc::new(russh::client::Config::default());
let mut ssh = russh::client::connect(config, ("localhost", 2222), Handler)
.await
.unwrap();

ssh.authenticate_password("user", "pass").await.unwrap();
let sftp = rusftp::SftpClient::new(ssh).await.unwrap();

let cwd = sftp.stat(rusftp::Stat { path: ".".into() }).await.unwrap();
println!("CWD: {:?}", cwd);

let realpath = sftp.realpath(RealPath { path: ".".into() }).await.unwrap();
println!("RealPath: {:?}", realpath);
}
151 changes: 115 additions & 36 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,42 @@
// limitations under the License.

use std::collections::HashMap;
use std::future::Future;

use async_trait::async_trait;
use bytes::Buf;
use russh::client::Msg;
use russh::Channel;
use russh::ChannelMsg;
use std::future::Future;
use tokio::sync::{mpsc, oneshot};

use crate::StatusCode;
use crate::{message, Message};

/// SFTP client
///
/// ```no_run
/// # use std::sync::Arc;
/// # use async_trait::async_trait;
/// struct Handler;
///
/// #[async_trait]
/// impl russh::client::Handler for Handler {
/// type Error = russh::Error;
/// // ...
/// }
///
/// # async fn dummy() -> Result<(), Box<dyn std::error::Error>> {
/// let config = Arc::new(russh::client::Config::default());
/// let mut ssh = russh::client::connect(config, ("localhost", 2222), Handler).await.unwrap();
/// ssh.authenticate_password("user", "pass").await.unwrap();
///
/// let sftp = rusftp::SftpClient::new(&ssh).await.unwrap();
/// let stat = sftp.stat(rusftp::Stat{path: ".".into()}).await.unwrap();
/// println!("stat '.': {stat:?}");
/// # Ok(())
/// # }
/// ```
pub struct SftpClient {
commands: mpsc::UnboundedSender<(Message, oneshot::Sender<Message>)>,
}
Expand Down Expand Up @@ -59,7 +86,10 @@ macro_rules! command {
}

impl SftpClient {
pub async fn new(mut channel: Channel<Msg>) -> Result<Self, std::io::Error> {
pub async fn new<T: ToSftpChannel>(ssh: T) -> Result<Self, std::io::Error> {
Self::with_channel(ssh.to_sftp_channel().await?).await
}
pub async fn with_channel(mut channel: Channel<Msg>) -> Result<Self, std::io::Error> {
// Start SFTP subsystem
match channel.request_subsystem(false, "sftp").await {
Ok(_) => (),
Expand Down Expand Up @@ -88,41 +118,46 @@ impl SftpClient {
})?;

// Check handshake response
match channel.wait().await {
Some(ChannelMsg::Data { data }) => {
match Message::decode(data.as_ref()) {
// Valid response: continue
Ok((
_,
Message::Version(message::Version {
version: 3,
extensions: _,
}),
)) => (),

// Invalid responses: abort
Ok((_, Message::Version(_))) => {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Invalid sftp version",
));
}
Ok(_) => {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Bad SFTP init",
));
}
Err(err) => {
return Err(std::io::Error::new(std::io::ErrorKind::Other, err));
loop {
match channel.wait().await {
Some(ChannelMsg::Data { data }) => {
match Message::decode(data.as_ref()) {
// Valid response: continue
Ok((
_,
Message::Version(message::Version {
version: 3,
extensions: _,
}),
)) => break,

// Invalid responses: abort
Ok((_, Message::Version(_))) => {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Invalid sftp version",
));
}
Ok(_) => {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Bad SFTP init",
));
}
Err(err) => {
return Err(std::io::Error::new(std::io::ErrorKind::Other, err));
}
}
}
}
_ => {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Failed to start SFTP subsystem",
));
// Unrelated event has been received, looping is required
Some(_) => (),
// Channel has been closed
None => {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Failed to start SFTP subsystem",
));
}
}
}

Expand Down Expand Up @@ -174,7 +209,21 @@ impl SftpClient {
}
}
Err(err) => {
eprintln!("SFTP Error: Could not decode server frame: {err}");
if let Some(mut buf) = data.as_ref().get(5..9){

let id = buf.get_u32();
if let Some(tx) = onflight.remove(&id) {
_ = tx.send(Message::Status(crate::Status {
code: StatusCode::BadMessage as u32,
error: err.to_string().into(),
language: "en".into(),
}));
} else {
eprintln!("SFTP Error: Received a reply with an invalid id");
}
} else {
eprintln!("SFTP Error: Received a bad reply");
}
}
}
},
Expand Down Expand Up @@ -226,3 +275,33 @@ impl SftpClient {
command!(readlink: ReadLink -> Name);
command!(symlink: Symlink);
}

#[async_trait]
pub trait ToSftpChannel {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, std::io::Error>;
}

#[async_trait]
impl ToSftpChannel for Channel<Msg> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, std::io::Error> {
Ok(self)
}
}

#[async_trait]
impl<H: russh::client::Handler> ToSftpChannel for &russh::client::Handle<H> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, std::io::Error> {
match self.channel_open_session().await {
Ok(channel) => Ok(channel),
Err(russh::Error::IO(err)) => Err(err),
Err(err) => Err(std::io::Error::new(std::io::ErrorKind::Other, err)),
}
}
}

#[async_trait]
impl<H: russh::client::Handler> ToSftpChannel for russh::client::Handle<H> {
async fn to_sftp_channel(self) -> Result<Channel<Msg>, std::io::Error> {
(&self).to_sftp_channel().await
}
}
22 changes: 22 additions & 0 deletions src/message/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,34 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::ops::Deref;

use bytes::Bytes;
use serde::{Deserialize, Serialize};

#[derive(Debug, PartialEq, Eq, Clone, Deserialize, Serialize)]
pub struct Handle(pub Bytes);

impl<T: Into<Bytes>> From<T> for Handle {
fn from(value: T) -> Self {
Self(value.into())
}
}

impl Deref for Handle {
type Target = [u8];

fn deref(&self) -> &Self::Target {
self.as_ref()
}
}

impl AsRef<[u8]> for Handle {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}

#[cfg(test)]
mod test {
use crate::message::test_utils::{encode_decode, fail_decode, BYTES_INVALID, BYTES_VALID};
Expand Down
2 changes: 1 addition & 1 deletion src/message/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ pub use handle::Handle;
pub use init::Init;
pub use lstat::LStat;
pub use mkdir::MkDir;
pub use name::Name;
pub use name::{Name, NameEntry};
pub use open::{pflags, Open};
pub use opendir::OpenDir;
pub use path::Path;
Expand Down
Loading